diff --git a/example/common/tokenizer.cc b/example/common/tokenizer.cc index d330753f..9541454a 100644 --- a/example/common/tokenizer.cc +++ b/example/common/tokenizer.cc @@ -10,6 +10,9 @@ #include "glog/logging.h" #include "example/common/utils.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/tensor.h" namespace infini_train { @@ -103,7 +106,7 @@ std::string Tokenizer::Decode(uint32_t token_id) const { } void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, - uint32_t text_length, const Device *device) const { + uint32_t text_length, Device device) const { std::vector dims; dims.assign({batch_size, sequence_length}); // x_tensor (FLAGS_batch_size, FLAGS_sequence_length) eq:(4, 64) @@ -121,7 +124,7 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz uint64_t kRngState = kRngState; LOG(INFO) << "start generate text:"; - const auto *cpu_device = DeviceManager::Instance()->GetDefaultDevice(); + auto cpu_device = Device(); for (int t = prompt_len; t < text_length; ++t) { x = std::make_shared(x->To(device)); // CPU->calc device // TODO(jym): use no_grad forward later diff --git a/example/common/tokenizer.h b/example/common/tokenizer.h index af42dd24..c9d0b76c 100644 --- a/example/common/tokenizer.h +++ b/example/common/tokenizer.h @@ -1,15 +1,13 @@ #include #include -#include #include #include "infini_train/include/device.h" -#include "infini_train/include/nn/functional.h" -#include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/tensor.h" namespace infini_train { - +namespace nn { +class Module; +} class Tokenizer { public: enum class Version : uint32_t { @@ -22,7 +20,7 @@ class Tokenizer { std::string Decode(uint32_t token_id) const; void GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, - uint32_t text_length, const Device *device) const; + uint32_t text_length, Device device) const; uint32_t GetEndToken() const { return eot_token_; }; diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 4a34c464..93fd361f 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -104,7 +104,7 @@ void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; // select the device - const Device *device; + Device device; int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); @@ -125,7 +125,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), @@ -149,8 +149,7 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::pp_rank = pp_rank; } } else { - device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); + device = FLAGS_device == kDeviceCPU ? 
Device() : Device(Device::DeviceType::kCUDA, 0); } // calculate gradient accumulation from the desired total batch size and the current run configuration @@ -201,7 +200,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared( model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), - rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + device, std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -285,7 +284,7 @@ void Train(const nn::parallel::Rank &rank) { for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + infini_train::AutocastGuard autocast_guard(device.type(), dtype); // (bs, seq_len), (bs, seq_len) auto [x, y] = *train_iter; @@ -308,7 +307,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward"; - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = loss->To(Device()); lossf += static_cast(loss_cpu.DataPtr())[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward"; loss->Backward(); @@ -330,8 +329,7 @@ void Train(const nn::parallel::Rank &rank) { if (ddp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); - lossf = static_cast( - lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } const auto iter_end = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index fdea2162..7e137abb 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -86,7 +86,7 @@ void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; // select the device - const Device *device; + Device device; int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); @@ -107,7 +107,7 @@ void Train(const nn::parallel::Rank &rank) { const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { - device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); + device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), @@ -131,8 +131,7 @@ void Train(const nn::parallel::Rank &rank) { nn::parallel::pp_rank = pp_rank; } } else { - device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); + device = FLAGS_device == kDeviceCPU ? 
Device() : Device(Device::DeviceType::kCUDA, 0); } // calculate gradient accumulation from the desired total batch size and the current run configuration @@ -181,7 +180,7 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared( model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), - rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + device, std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { @@ -261,7 +260,7 @@ void Train(const nn::parallel::Rank &rank) { for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + infini_train::AutocastGuard autocast_guard(device.type(), dtype); // (bs, seq_len), (bs, seq_len) auto [x, y] = *train_iter; @@ -284,7 +283,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward"; - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = loss->To(Device()); lossf += static_cast(loss_cpu.DataPtr())[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward"; loss->Backward(); @@ -306,8 +305,7 @@ void Train(const nn::parallel::Rank &rank) { if (ddp_world_size > 1) { auto lossf_tensor = std::make_shared(&lossf, std::vector{}, DataType::kFLOAT32, device); function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg); - lossf = static_cast( - lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } const auto iter_end = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a70a811a..8754fe3a 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -100,8 +100,7 @@ ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, - const infini_train::Device *device - = DeviceManager::Instance()->GetDefaultDevice()) { + infini_train::Device device = Device()) { DataType dtype = DataType::kFLOAT32; CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); @@ -127,7 +126,7 @@ std::vector> SwiGLU::Forward(const std::vector(std::vector{dim}, DataType::kFLOAT32, device)->RequiresGrad(); nn::init::Ones(parameters_[kParamWeightName]); diff --git a/example/llama3/net.h b/example/llama3/net.h index 9bd7f9da..88604daf 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -51,8 +51,7 @@ class RMSNorm : public infini_train::nn::CloneableModule { public: static constexpr char kParamWeightName[] = "weight"; - explicit RMSNorm(int64_t dim, float eps = 1e-6f, - const infini_train::Device *device = infini_train::DeviceManager::Instance()->GetDefaultDevice()); + explicit RMSNorm(int64_t dim, float eps = 1e-6f, infini_train::Device device = infini_train::Device()); std::vector> Forward(const std::vector> &x) override; diff --git a/example/mnist/main.cc b/example/mnist/main.cc index 097529bf..e62257d7 100644 --- a/example/mnist/main.cc +++ b/example/mnist/main.cc @@ -48,9 +48,8 @@ int main(int argc, char *argv[]) { DataLoader test_dataloader(test_dataset, FLAGS_bs); auto network = MNIST(); - const 
Device *device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0) - : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); - const Device *cpu_device = DeviceManager::Instance()->GetDefaultDevice(); + Device device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0); + Device cpu_device = Device(); network.To(device); auto loss_fn = nn::CrossEntropyLoss(); diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index e5bcf6af..499c586f 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -3,15 +3,10 @@ #include #include -#include "common/common.h" -#include "datatype.h" -#include "device.h" -#include "tensor.h" - -#ifdef USE_CUDA -#include -#include -#endif +#include "infini_train/include/common/common.h" +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" +#include "infini_train/include/tensor.h" namespace infini_train { namespace { @@ -91,18 +86,18 @@ inline const std::unordered_map kOpCastPolicyMap = }; // Default autocast data types for each device type -inline constexpr std::array(DeviceType::kCount)> kDeviceDefaultDtype = { +inline constexpr std::array(Device::DeviceType::kCount)> kDeviceDefaultDtype = { DataType::kBFLOAT16, // CPU DataType::kFLOAT16, // CUDA. }; // Thread-local context to track autocast state struct AutocastContext { - bool enabled = false; // Whether autocast is active in the current thread - DeviceType device_type = DeviceType::kCPU; // Target device type (CPU/GPU) - DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting + bool enabled = false; // Whether autocast is active in the current thread + Device::DeviceType device_type = Device::DeviceType::kCPU; // Target device type (CPU/GPU) + DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting - template void Autocast(std::pair key, ArgsT &...args) { + template void Autocast(std::pair key, ArgsT &...args) { if (!enabled) { return; } @@ -172,14 +167,14 @@ inline thread_local AutocastContext tls_autocast_context; // RAII guard to enable/disable autocast in a scope class AutocastGuard { public: - AutocastGuard(DeviceType device_type, DataType autocast_dtype) { + AutocastGuard(Device::DeviceType device_type, DataType autocast_dtype) { saved_context_ = tls_autocast_context; tls_autocast_context.enabled = true; tls_autocast_context.device_type = device_type; tls_autocast_context.autocast_dtype = autocast_dtype; } - AutocastGuard(DeviceType device_type) + AutocastGuard(Device::DeviceType device_type) : AutocastGuard(device_type, kDeviceDefaultDtype[static_cast(device_type)]) {} // Disable autocast (restore previous state) diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index b54c814f..e74c821d 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -4,6 +4,7 @@ #include #include "infini_train/include/autograd/function.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; @@ -19,7 +20,7 @@ class Scatter : public autograd::Function { public: static constexpr char kType[] = "ScatterFunction"; - explicit Scatter(const std::vector &target_gpus, int64_t dim, + explicit Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -31,8 +32,8 @@ class Scatter : public 
autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - std::vector target_gpus_; - const Device *input_device_ = nullptr; + std::vector target_gpus_; + Device input_device_ = Device(); int64_t dim_ = 0; }; @@ -40,8 +41,7 @@ class Gather : public autograd::Function { public: static constexpr char kType[] = "GatherFunction"; - explicit Gather(const Device *target_device, int64_t dim, - const infini_train::nn::parallel::ProcessGroup *pg = nullptr); + explicit Gather(Device target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -52,8 +52,8 @@ class Gather : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - const Device *target_device_ = nullptr; - std::vector input_gpus_; + Device target_device_ = Device(); + std::vector input_gpus_; int64_t dim_ = 0; bool unsqueezed_scalar_ = false; }; @@ -62,7 +62,7 @@ class Broadcast : public autograd::Function { public: static constexpr char kType[] = "BroadcastFunction"; - explicit Broadcast(const std::vector &target_gpus, + explicit Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -74,16 +74,16 @@ class Broadcast : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - std::vector target_gpus_; + std::vector target_gpus_; int64_t num_inputs_ = 0; - const Device *input_device_ = nullptr; + Device input_device_ = Device(); }; class ReduceAddCoalesced : public autograd::Function { public: static constexpr char kType[] = "ReduceAddCoalescedFunction"; - explicit ReduceAddCoalesced(const Device *destination, int64_t num_inputs, + explicit ReduceAddCoalesced(Device destination, int64_t num_inputs, const infini_train::nn::parallel::ProcessGroup *pg = nullptr); std::vector> Forward(const std::vector> &input_tensors) override; @@ -95,8 +95,8 @@ class ReduceAddCoalesced : public autograd::Function { private: const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr; - const Device *destination_ = nullptr; - std::vector target_gpus_; + Device destination_ = Device(); + std::vector target_gpus_; int64_t num_inputs_ = 0; }; } // namespace infini_train::autograd diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h index 9d726d35..b6a02543 100644 --- a/infini_train/include/common/common.h +++ b/infini_train/include/common/common.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include "glog/logging.h" #include "infini_train/include/datatype.h" diff --git a/infini_train/include/core/blas_handle.h b/infini_train/include/core/blas_handle.h new file mode 100644 index 00000000..56b058ce --- /dev/null +++ b/infini_train/include/core/blas_handle.h @@ -0,0 +1,11 @@ +#pragma once + +namespace infini_train::core { + +class BlasHandle { +public: + BlasHandle(){}; + virtual ~BlasHandle() = default; +}; + +} // namespace infini_train::core diff --git a/infini_train/include/core/device_guard.h b/infini_train/include/core/device_guard.h new file mode 100644 index 00000000..0ff98d33 --- /dev/null +++ b/infini_train/include/core/device_guard.h @@ -0,0 +1,185 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/device.h" + +namespace infini_train::core { + +class Stream; +class BlasHandle; + +enum class MemcpyKind : int8_t { + kH2D = 0, + kD2H = 1, + 
kD2D = 2,
+    kInvalid = -1,
+};
+
+//
+// ----------------------------------------------------------------------------
+// DeviceGuardImpl: Backend-specific device/stream/memory/BLAS implementation
+// ----------------------------------------------------------------------------
+// This is the low-level virtual interface that each backend must implement.
+// Examples:
+//   - CUDA:   CudaDeviceGuardImpl
+//   - CPU:    CpuDeviceGuardImpl
+//   - Custom: MyChipDeviceGuardImpl
+//
+// DeviceGuardImpl encapsulates **all device-runtime behaviors**, including:
+//
+//   • Querying / setting the current device
+//   • Stream creation/lookup
+//   • Synchronization primitives
+//   • Memory allocation & copy
+//   • Access to BLAS handles
+//
+// DeviceGuard (the public RAII wrapper) forwards calls to the DeviceGuardImpl
+// instance registered for the device type.
+//
+// TODO(zbl): add event management
+//
+class DeviceGuardImpl {
+public:
+    DeviceGuardImpl() {}
+
+    virtual ~DeviceGuardImpl() = default;
+
+    // ----------------------------------------------------------------------
+    // Device management
+    // ----------------------------------------------------------------------
+
+    virtual Device GetDevice() const = 0;
+
+    virtual void SetDevice(Device device) const {}
+
+    virtual int8_t DeviceCount() const;
+
+    virtual Device::DeviceType Type() const = 0;
+
+    // ----------------------------------------------------------------------
+    // Stream management
+    // ----------------------------------------------------------------------
+
+    virtual Stream *GetStream(Device) const;
+
+    // ----------------------------------------------------------------------
+    // Synchronization
+    // ----------------------------------------------------------------------
+
+    virtual void SynchronizeDevice(Device) const {}
+
+    virtual void SynchronizeStream(Stream *) const {}
+
+    // ----------------------------------------------------------------------
+    // BLAS handle
+    // ----------------------------------------------------------------------
+
+    virtual BlasHandle *GetBlasHandle(Device) const;
+
+    // ----------------------------------------------------------------------
+    // Memory operations
+    // ----------------------------------------------------------------------
+
+    virtual void Malloc(void **dev_ptr, size_t size) = 0;
+
+    virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream);
+
+    virtual void Free(void *dev_ptr) = 0;
+
+    virtual void FreeAsync(void *dev_ptr, Stream *stream);
+
+    virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0;
+
+    virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream);
+};
+
+//
+// ----------------------------------------------------------------------------
+// DeviceGuard: RAII front-end wrapper for DeviceGuardImpl
+// ----------------------------------------------------------------------------
+// This class is the **public-facing device interface** for the framework.
+// It automatically:
+//
+//   • Saves the current device on construction
+//   • Switches to the target device
+//   • Restores the previous device on destruction
+//
+// All runtime operations are forwarded to the backend-specific DeviceGuardImpl
+// instance registered for that device type.
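// [Editor's illustration -- not part of this patch] A minimal usage sketch of the
// RAII pattern described above, assuming a CUDA build with at least two devices:
//
//     using namespace infini_train;
//     {
//         core::DeviceGuard guard(Device(Device::DeviceType::kCUDA, 1));
//         // work issued here (allocations, copies, kernel launches) targets CUDA device 1
//     } // guard is destroyed; the previously current device is restored
//
// This is the same pattern the rest of this diff switches to: Function::Apply and
// AccumulateGrad::Backward now construct core::DeviceGuard guard(device); instead
// of calling the old device->SetDevice().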
+
+//
+class DeviceGuard {
+public:
+    explicit DeviceGuard(Device device);
+
+    ~DeviceGuard();
+
+    // Copy is disallowed
+    DeviceGuard(const DeviceGuard &) = delete;
+    DeviceGuard &operator=(const DeviceGuard &) = delete;
+
+    // Move is disallowed, as DeviceGuard does not have an uninitialized state,
+    // which is required for moves on types with nontrivial destructors.
+    DeviceGuard(DeviceGuard &&other) = delete;
+    DeviceGuard &operator=(DeviceGuard &&other) = delete;
+
+    void SetDevice(Device device);
+
+    Device current_device() const;
+
+    Device original_device() const;
+
+private:
+    DeviceGuardImpl *impl_ = nullptr;
+    Device original_device_;
+    Device current_device_;
+};
+
+//
+// ----------------------------------------------------------------------------
+// DeviceGuardImplRegistry: Global registry of backend implementations
+// ----------------------------------------------------------------------------
+// This registry stores at most one DeviceGuardImpl per DeviceType.
+// Backends register themselves at static initialization time via the macro
+// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL().
+//
+// Example:
+//   INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl)
+//
+class DeviceGuardImplRegistry {
+public:
+    static DeviceGuardImplRegistry &Instance();
+
+    void Register(Device::DeviceType type, std::unique_ptr<DeviceGuardImpl> impl);
+
+    DeviceGuardImpl *Get(Device::DeviceType type) const;
+
+private:
+    DeviceGuardImplRegistry() = default;
+    DeviceGuardImplRegistry(const DeviceGuardImplRegistry &) = delete;
+    DeviceGuardImplRegistry &operator=(const DeviceGuardImplRegistry &) = delete;
+
+    std::unordered_map<Device::DeviceType, std::unique_ptr<DeviceGuardImpl>> impls_;
+};
+
+DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type);
+
+} // namespace infini_train::core
+
+//
+// ----------------------------------------------------------------------------
+// Registration macro
+// ----------------------------------------------------------------------------
+// Registers a DeviceGuardImpl implementation into the global registry
+// at static initialization time.
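// [Editor's illustration -- not part of this patch] A sketch of what a minimal
// backend could look like, assuming only the pure-virtual subset of DeviceGuardImpl
// above needs overriding. The class name ExampleCpuGuardImpl and its bodies are
// hypothetical (the real CpuDeviceGuardImpl/CudaDeviceGuardImpl in this PR may
// differ), and <cstdlib>/<cstring> are assumed to be included.
//
//     using infini_train::Device;
//     using infini_train::core::MemcpyKind;
//
//     class ExampleCpuGuardImpl : public infini_train::core::DeviceGuardImpl {
//     public:
//         Device GetDevice() const override { return Device(Device::DeviceType::kCPU, 0); }
//         Device::DeviceType Type() const override { return Device::DeviceType::kCPU; }
//         void Malloc(void **dev_ptr, size_t size) override { *dev_ptr = std::malloc(size); }
//         void Free(void *dev_ptr) override { std::free(dev_ptr); }
//         void Memcpy(void *dst, const void *src, size_t count, MemcpyKind /*kind*/) override {
//             std::memcpy(dst, src, count);
//         }
//     };
//
//     // Registered once at static-initialization time via the macro defined below:
//     INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCPU, ExampleCpuGuardImpl)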
+// +// Example usage: +// INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) +// +#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \ + static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \ + infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique()); \ + return true; \ + }(); diff --git a/infini_train/include/core/stream.h b/infini_train/include/core/stream.h new file mode 100644 index 00000000..190298f6 --- /dev/null +++ b/infini_train/include/core/stream.h @@ -0,0 +1,10 @@ +#pragma once + +namespace infini_train::core { + +class Stream { +public: + virtual ~Stream() = default; +}; + +} // namespace infini_train::core diff --git a/infini_train/include/device.h b/infini_train/include/device.h index 36357a09..28db395f 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -1,101 +1,49 @@ #pragma once #include -#include -#include -#include - -#ifdef USE_CUDA -#include -#endif - -#include "glog/logging.h" +#include +#include #include "infini_train/include/nn/parallel/rank.h" namespace infini_train { -enum class DeviceType : int8_t { - kCPU = 0, - kCUDA = 1, - kCount = 2, -}; - -class DeviceManager; - class Device { public: - virtual ~Device() = default; - - DeviceType Type() const; - int8_t Index() const; - - bool IsCPU() const; - bool IsCUDA() const; + enum class DeviceType : int8_t { + kCPU = 0, + kCUDA = 1, + kCount = 2, + kInvalid = -1, + }; - virtual void SetDevice() const {} - virtual void Synchronize() const {} + Device(); - std::string ToString() const; - - virtual nn::parallel::Rank rank() const; - - friend std::ostream &operator<<(std::ostream &os, const Device &device); - -protected: Device(DeviceType type, int8_t index); - DeviceType type_; - int8_t index_; -}; + Device &operator=(const Device &) = default; -class CpuDevice : public Device { -private: - CpuDevice(); + ~Device() = default; - friend class DeviceManager; -}; + DeviceType type() const; + int8_t index() const; -#ifdef USE_CUDA -class CudaDevice : public Device { -public: - ~CudaDevice() override; - - void SetDevice() const override; - void Synchronize() const override; - - cudaStream_t Stream() const; - - cublasHandle_t CublasHandle() const; - - nn::parallel::Rank rank() const override; - -private: - CudaDevice(int8_t index); - - cudaStream_t stream_ = nullptr; - - cublasHandle_t cublas_handle_ = nullptr; - - nn::parallel::Rank rank_; + bool IsCPU() const; + bool IsCUDA() const; - friend class DeviceManager; -}; -#endif + std::string ToString() const; -class DeviceManager { -public: - static const DeviceManager *Instance(); + virtual nn::parallel::Rank Rank() const; - const Device *GetDevice(DeviceType type, int8_t index = 0) const; + friend std::ostream &operator<<(std::ostream &os, const Device &device); - const Device *GetDefaultDevice() const; + friend bool operator==(const Device &a, const Device &b); - std::vector GetAllAvailableDevices(DeviceType device_type) const; + friend bool operator!=(const Device &a, const Device &b); private: - DeviceManager(); - - std::unordered_map>> devices_map_; + DeviceType type_ = DeviceType::kInvalid; + int8_t index_ = -1; }; + } // namespace infini_train diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 7b87d59a..fc95d64e 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -260,10 +260,8 @@ auto DispatchFunc(DataType dtype, Functor &&func, std::string_view 
context_ident CASE_FOR_TYPE(DataType::kINT64) CASE_FOR_TYPE(DataType::kFLOAT32) CASE_FOR_TYPE(DataType::kFLOAT64) -#ifdef USE_CUDA CASE_FOR_TYPE(DataType::kBFLOAT16) CASE_FOR_TYPE(DataType::kFLOAT16) -#endif #undef CASE_FOR_TYPE } LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); @@ -328,10 +326,8 @@ template st CASE_FOR_TYPE(DataType::kINT64) CASE_FOR_TYPE(DataType::kFLOAT32) CASE_FOR_TYPE(DataType::kFLOAT64) -#ifdef USE_CUDA CASE_FOR_TYPE(DataType::kBFLOAT16) CASE_FOR_TYPE(DataType::kFLOAT16) -#endif #undef CASE_FOR_TYPE } LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); @@ -413,7 +409,7 @@ class KernelFunction { class Dispatcher { public: - using KeyT = std::pair; + using KeyT = std::pair; static Dispatcher &Instance() { static Dispatcher instance; diff --git a/infini_train/include/nn/init.h b/infini_train/include/nn/init.h index 644df590..fc6effec 100644 --- a/infini_train/include/nn/init.h +++ b/infini_train/include/nn/init.h @@ -6,6 +6,7 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; @@ -50,5 +51,5 @@ std::shared_ptr Ones(const std::shared_ptr &tensor); std::shared_ptr Zeros(const std::shared_ptr &tensor); -std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, const Device *device = nullptr); +std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device = Device()); } // namespace infini_train::nn::init diff --git a/infini_train/include/nn/modules/linear.h b/infini_train/include/nn/modules/linear.h index e02b91a6..c4103df6 100644 --- a/infini_train/include/nn/modules/linear.h +++ b/infini_train/include/nn/modules/linear.h @@ -3,6 +3,7 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { @@ -18,7 +19,7 @@ class Linear : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; static constexpr char kParamBiasName[] = "bias"; - Linear(int64_t in_features, int64_t out_features, bool bias = true, const Device *device = nullptr); + Linear(int64_t in_features, int64_t out_features, bool bias = true, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 9bc78bcc..c1e238b1 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -7,10 +7,10 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn { @@ -18,7 +18,7 @@ class Module; namespace parallel::function { std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices); + const std::vector &devices); } // namespace parallel::function class Module : public std::enable_shared_from_this { @@ -58,7 +58,7 @@ class Module : public std::enable_shared_from_this { return 0.0f; }; - virtual void To(const Device *device); + virtual void To(Device device); virtual void To(DataType dtype); @@ -67,7 +67,7 @@ class Module : public std::enable_shared_from_this { virtual std::shared_ptr ReplicateForDataParallel(int device_idx) const; protected: - const Device *device_ = nullptr; + Device device_; const std::string type_ = kUndefinedType; std::unordered_map> modules_; std::unordered_map> parameters_; @@ -78,8 +78,8 @@ class Module : public std::enable_shared_from_this 
{ NamedModules(const std::string &prefix = "", bool remove_duplicate = true, std::unordered_set *memory = nullptr); - friend std::vector> - parallel::function::Replicate(const std::shared_ptr &network, const std::vector &devices); + friend std::vector> parallel::function::Replicate(const std::shared_ptr &network, + const std::vector &devices); }; template class CloneableModule : public Module { diff --git a/infini_train/include/nn/modules/normalization.h b/infini_train/include/nn/modules/normalization.h index 4dcdf807..926a58f8 100644 --- a/infini_train/include/nn/modules/normalization.h +++ b/infini_train/include/nn/modules/normalization.h @@ -3,11 +3,11 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn { @@ -16,7 +16,7 @@ class LayerNorm : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; static constexpr char kParamBiasName[] = "bias"; - LayerNorm(const std::vector &normalized_shape, float eps = 1e-5f, const Device *device = nullptr); + LayerNorm(const std::vector &normalized_shape, float eps = 1e-5f, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/modules/sparse.h b/infini_train/include/nn/modules/sparse.h index e0605a6e..51975160 100644 --- a/infini_train/include/nn/modules/sparse.h +++ b/infini_train/include/nn/modules/sparse.h @@ -3,6 +3,7 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { @@ -17,7 +18,7 @@ class Embedding : public CloneableModule { static constexpr char kParamWeightName[] = "weight"; - Embedding(int num_embeddings, int embedding_dim, const Device *device = nullptr); + Embedding(int num_embeddings, int embedding_dim, Device device = Device()); std::vector> Forward(const std::vector> &input_tensors) override; private: diff --git a/infini_train/include/nn/parallel/data_parallel.h b/infini_train/include/nn/parallel/data_parallel.h index 581d6c3b..7d97f282 100644 --- a/infini_train/include/nn/parallel/data_parallel.h +++ b/infini_train/include/nn/parallel/data_parallel.h @@ -3,11 +3,11 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn::parallel { @@ -19,8 +19,8 @@ class DataParallel : public Module { private: int dim_ = 0; - std::vector devices_; - const Device *output_device_ = nullptr; - const Device *src_device_ = nullptr; + std::vector devices_; + Device output_device_ = Device(); + Device src_device_ = Device(); }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/parallel_functional.h b/infini_train/include/nn/parallel/parallel_functional.h index f2559e2d..2eed56f4 100644 --- a/infini_train/include/nn/parallel/parallel_functional.h +++ b/infini_train/include/nn/parallel/parallel_functional.h @@ -3,12 +3,12 @@ #include #include +#include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" namespace infini_train { class Tensor; -class Device; namespace nn { class Module; } @@ -26,16 +26,15 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const ReduceOpType reduce_op, 
const ProcessGroup *pg = nullptr, bool async_op = false); std::vector>> Scatter(const std::vector> &input_tensors, - const std::vector &device_ids, int dim); + const std::vector &device_ids, int dim); std::vector> Gather(const std::vector>> &outputs, - const Device *target_device, int dim); + Device target_device, int dim); std::vector>> -BroadcastCoalescedReshape(const std::vector> &tensors, - const std::vector &devices); +BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices); std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices); + const std::vector &devices); } // namespace infini_train::nn::parallel::function diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index 52a776ea..a45529b4 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -3,6 +3,8 @@ #include #include +#include "infini_train/include/device.h" + namespace infini_train { class Tensor; class Device; @@ -17,7 +19,7 @@ namespace infini_train::nn::parallel { class PipelineStage { public: PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, - std::shared_ptr optimizer, int device_id, std::vector> &&chunks); + std::shared_ptr optimizer, Device device, std::vector> &&chunks); std::vector> ForwardOneChunk(const std::vector> &inputs, int local_chunk_idx = 0); @@ -30,7 +32,7 @@ class PipelineStage { int next_rank() const; int num_stages() const; - const Device *device() const; + Device device() const; const std::vector> &recv_shape() const; std::shared_ptr optimizer(); const std::vector> &chunks(); @@ -41,7 +43,7 @@ class PipelineStage { int num_stages_ = -1; int prev_rank_ = -1; int next_rank_ = -1; - const Device *device_ = nullptr; + Device device_ = Device(); std::vector> chunks_; std::shared_ptr optimizer_ = nullptr; std::vector> recv_shape_; diff --git a/infini_train/include/nn/parallel/pp/send_recv.h b/infini_train/include/nn/parallel/pp/send_recv.h index f76f4c72..4f8687ab 100644 --- a/infini_train/include/nn/parallel/pp/send_recv.h +++ b/infini_train/include/nn/parallel/pp/send_recv.h @@ -3,17 +3,18 @@ #include #include +#include "infini_train/include/device.h" + namespace infini_train { class Tensor; -class Device; } // namespace infini_train namespace infini_train::nn::parallel { std::vector> ISend(const std::vector> &input_tensors, - const Device *target_device, int cur_rank, int peer_rank, + Device target_device, int cur_rank, int peer_rank, const std::vector> &shape); -std::vector> IRecv(const std::vector> &outputs, - const Device *src_device, int cur_rank, int peer_rank); +std::vector> IRecv(const std::vector> &outputs, Device src_device, + int cur_rank, int peer_rank); } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index d739f67d..79d84478 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -12,11 +12,11 @@ #include #endif +#include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" namespace infini_train { class Tensor; -class Device; namespace nn { class Module; namespace parallel { @@ -62,21 +62,20 @@ class ProcessGroup { BroadCast(const std::vector> &input_tensors) const = 0; virtual std::vector> - ReduceAddCoalesced(const std::vector>> &grads, const Device *destination) 
const - = 0; + ReduceAddCoalesced(const std::vector>> &grads, Device destination) const = 0; virtual std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const + std::vector devices, int64_t dim) const = 0; - virtual std::shared_ptr Gather(const std::vector> &tensors, - const Device *destination, int64_t dim) const + virtual std::shared_ptr Gather(const std::vector> &tensors, Device destination, + int64_t dim) const = 0; protected: ProcessGroup(int world_size, const std::string &name); - std::vector devices_; + std::vector devices_; std::unordered_map global_group_rank_map_; // global_rank : group_rank @@ -116,12 +115,12 @@ class ProcessGroupNCCL final : public ProcessGroup { std::vector> ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) const override; + Device destination) const override; - std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const override; + std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, + int64_t dim) const override; - std::shared_ptr Gather(const std::vector> &tensors, const Device *destination, + std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) const override; private: @@ -135,8 +134,8 @@ class ProcessGroupNCCL final : public ProcessGroup { std::vector comms_; std::vector comm_streams_; - std::unordered_map device_comm_map_; - std::unordered_map device_stream_map_; + std::unordered_map device_comm_map_; + std::unordered_map device_stream_map_; }; #endif diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index 1e11cc02..8cc60f78 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -12,9 +12,7 @@ #include #endif -namespace infini_train { -class Device; -} // namespace infini_train +#include "infini_train/include/device.h" namespace infini_train::nn::parallel { @@ -39,7 +37,7 @@ class Work { #ifdef USE_NCCL class WorkNccl final : public Work { public: - WorkNccl(const Device *device, ncclComm_t comm); + WorkNccl(Device device, ncclComm_t comm); ~WorkNccl() override; bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; @@ -60,7 +58,7 @@ class WorkNccl final : public Work { void SetException(std::exception_ptr e); private: - const Device *device_ = nullptr; + Device device_ = Device(); cudaEvent_t ready_event_; cudaEvent_t done_event_; ncclComm_t comm_; diff --git a/infini_train/include/profiler.h b/infini_train/include/profiler.h index bb54bfcc..6e0cf06f 100644 --- a/infini_train/include/profiler.h +++ b/infini_train/include/profiler.h @@ -17,12 +17,12 @@ inline thread_local int g_profiling_depth = 0; struct ProfileContext { std::string name; - DeviceType device; + Device::DeviceType device; }; inline thread_local ProfileContext g_profile_context; -inline void SetProfileContext(const std::string &name, DeviceType device) { +inline void SetProfileContext(const std::string &name, Device::DeviceType device) { if (g_profiling_depth == 0) { g_profile_context.name = name; g_profile_context.device = device; @@ -63,8 +63,8 @@ class Profiler { static Profiler &Instance(); - void StartRecord(const std::string &name, DeviceType device); - void EndRecord(const std::string &name, DeviceType device); + void StartRecord(const std::string &name, Device::DeviceType device); + void EndRecord(const std::string &name, Device::DeviceType device); void Report(std::ostream &os = std::cout, SortBy 
sort_by = SortBy::NotSorted) const; void Report(const std::string &file_path, SortBy sort_by = SortBy::NotSorted) const; @@ -84,17 +84,15 @@ class Profiler { std::vector call_records_; std::string current_tag_ = "Untagged"; + // thread-local tracking + thread_local static inline std::map cpu_timing_map_; + #ifdef USE_CUDA struct EventPair { void *start; void *stop; }; -#endif - // thread-local tracking - thread_local static inline std::map cpu_timing_map_; - -#ifdef USE_CUDA thread_local static inline std::map cuda_timing_map_; #endif }; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index b499b604..b99e4044 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -10,8 +10,10 @@ #include "Eigen/Dense" #include "glog/logging.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" +#include "infini_train/include/dispatcher.h" namespace infini_train { namespace autograd { @@ -38,17 +40,17 @@ struct PrintOptions { class TensorBuffer { public: - TensorBuffer(const Device *device, size_t size); + TensorBuffer(Device device, size_t size); ~TensorBuffer(); void *DataPtr(); const void *DataPtr() const; - const Device *GetDevice() const; + Device GetDevice() const; size_t Size() const; private: - const Device *device_ = nullptr; + Device device_ = Device(); size_t size_ = 0; void *data_ = nullptr; }; @@ -57,17 +59,15 @@ class Tensor : public std::enable_shared_from_this { public: Tensor() = default; - Tensor(const std::vector &dims, DataType dtype, const Device *device); - Tensor(const std::vector &dims, DataType dtype) - : Tensor(dims, dtype, DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)) {} + Tensor(const std::vector &dims, DataType dtype, Device device); + Tensor(const std::vector &dims, DataType dtype) : Tensor(dims, dtype, Device()) {} Tensor(const Tensor &tensor, size_t offset, const std::vector &dims); - Tensor(const float *data, const std::vector &dims, DataType dtype, const Device *device); - Tensor(const float *data, const std::vector &dims, DataType dtype) - : Tensor(data, dims, dtype, DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)) {} + Tensor(const float *data, const std::vector &dims, DataType dtype, Device device); + Tensor(const float *data, const std::vector &dims, DataType dtype) : Tensor(data, dims, dtype, Device()) {} - const Device *GetDevice() const; + Device GetDevice() const; void *DataPtr(); const void *DataPtr() const; @@ -78,13 +78,29 @@ class Tensor : public std::enable_shared_from_this { size_t NumElements() const; DataType Dtype() const; - template void Fill(T value); + // TODO(dcj): use scalar class later + template void Fill(T value) { + auto device = GetDevice(); + core::DeviceGuard guard(device); + + DataType dtype = Dtype(); + + uint64_t storage = 0; + + DispatchFunc(Dtype(), [&storage, value]() { + TargetT casted_value = static_cast(value); + std::memcpy((void *)(&storage), &casted_value, sizeof(TargetT)); + }); + + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Fill"}); + kernel.Call(shared_from_this(), static_cast(&storage)); + } Eigen::Map> EigenMatrix(); Eigen::Map> EigenVector(); // TODO(dcj): return shared_ptr instead of Tensor later - Tensor To(const Device *device); + Tensor To(Device device); Tensor To(DataType dtype); void CopyFrom(const Tensor &src); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index def9cad8..165147ef 100644 --- 
a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -3,6 +3,7 @@ #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -22,7 +23,7 @@ AccumulateGrad::Backward(const std::vector> &grad_output auto grad = tensor_->grad(); auto device = tensor_->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); if (grad_output) { if (grad) { @@ -32,7 +33,7 @@ AccumulateGrad::Backward(const std::vector> &grad_output // NOTE(zbl): must copy, cannot change grad buffer address grad->CopyFrom(grad_output); } else { - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(grad_output, learning_rate_, grad); } } else { diff --git a/infini_train/src/autograd/activations.cc b/infini_train/src/autograd/activations.cc index 1706082b..3641865a 100644 --- a/infini_train/src/autograd/activations.cc +++ b/infini_train/src/autograd/activations.cc @@ -10,7 +10,7 @@ std::vector> Sigmoid::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SigmoidForward"}, input)}; } @@ -26,7 +26,7 @@ std::vector> Sigmoid::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SigmoidBackward"}, output, grad_output)}; } } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index 1bcad973..0e0028d0 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -10,7 +10,7 @@ namespace infini_train::autograd { -Scatter::Scatter(const std::vector &target_gpus, int64_t dim, +Scatter::Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_gpus_(target_gpus), dim_(dim), pg_(pg ? pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} @@ -18,7 +18,7 @@ Scatter::Scatter(const std::vector &target_gpus, int64_t dim, std::vector> Scatter::Forward(const std::vector> &input_tensors) { const auto &input = input_tensors[0]; std::vector> output_tensors; - auto device = input->GetDevice()->Type(); + auto device = input->GetDevice().type(); output_tensors = pg_->Scatter(input, target_gpus_, dim_); return output_tensors; } @@ -32,13 +32,13 @@ std::vector> Scatter::Backward(const std::vector(input_device_, dim_)->Apply(grad_outputs); } -Gather::Gather(const Device *target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) +Gather::Gather(Device target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_device_(target_device), dim_(dim), pg_(pg ? 
pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} std::vector> Gather::Forward(const std::vector> &input_tensors) { for (const auto &tensor : input_tensors) { - CHECK_NE(static_cast(tensor->GetDevice()->Type()), static_cast(DeviceType::kCPU)) + CHECK_NE(static_cast(tensor->GetDevice().type()), static_cast(Device::DeviceType::kCPU)) << "Gather function not implemented for CPU tensors"; } if (dim_ == 0 && input_tensors[0]->Dims().size() == 0) { @@ -51,7 +51,7 @@ std::vector> Gather::Forward(const std::vectorGetDevice()->Type(); + auto device = input_tensors[0]->GetDevice().type(); return {pg_->Gather(input_tensors, target_device_, dim_)}; } @@ -62,10 +62,10 @@ void Gather::SetupContext(const std::vector> &input_tens std::vector> Gather::Backward(const std::vector> &grad_outputs) { // TODO(dcj): do squeeze here if unsqueezed_scalar_ is true - return std::make_shared(std::vector{input_gpus_}, dim_)->Apply(grad_outputs); + return std::make_shared(std::vector{input_gpus_}, dim_)->Apply(grad_outputs); } -Broadcast::Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg) +Broadcast::Broadcast(const std::vector &target_gpus, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), target_gpus_(target_gpus), pg_(pg ? pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} @@ -78,7 +78,7 @@ std::vector> Broadcast::Forward(const std::vectorGetDevice()->IsCPU()) << "Broadcast function not implemented for CPU tensors"; - CHECK(tensor->GetDevice()->Type() == input_device_->Type()) + CHECK(tensor->GetDevice().type() == input_device_.type()) << "Broadcast function not implemented for tensors on different device type"; } @@ -95,7 +95,7 @@ std::vector> Broadcast::Backward(const std::vector(input_device_, num_inputs_)->Apply(grad_outputs); } -ReduceAddCoalesced::ReduceAddCoalesced(const Device *destination, int64_t num_inputs, +ReduceAddCoalesced::ReduceAddCoalesced(Device destination, int64_t num_inputs, const infini_train::nn::parallel::ProcessGroup *pg) : autograd::Function(kType), destination_(destination), num_inputs_(num_inputs), pg_(pg ? 
pg : infini_train::nn::parallel::ProcessGroupFactory::Instance()->GetDefaultProcessGroup()) {} diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index a00536a3..7291e284 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -10,7 +10,7 @@ std::vector> Neg::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NegForward"}, input)}; } @@ -18,7 +18,7 @@ std::vector> Neg::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NegBackward"}, grad_output)}; } @@ -26,7 +26,7 @@ std::vector> Reciprocal::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ReciprocalForward"}, input)}; } @@ -42,7 +42,7 @@ std::vector> Reciprocal::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ReciprocalBackward"}, grad_output, input)}; } @@ -50,7 +50,7 @@ std::vector> Sin::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SinForward"}, input)}; } @@ -66,7 +66,7 @@ std::vector> Sin::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SinBackward"}, grad_output, input)}; } @@ -74,7 +74,7 @@ std::vector> Cos::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CosForward"}, input)}; } @@ -90,7 +90,7 @@ std::vector> Cos::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CosBackward"}, grad_output, input)}; } @@ -98,7 +98,7 @@ std::vector> Tanh::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TanhForward"}, input)}; } @@ -114,7 +114,7 @@ std::vector> Tanh::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TanhBackward"}, grad_output, output)}; } @@ -122,7 +122,7 @@ std::vector> Pow::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "PowForward"}, input, exponent_, scalar_is_base_)}; } @@ -139,7 +139,7 @@ std::vector> Pow::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "PowBackward"}, grad_output, input, exponent_, scalar_is_base_)}; } @@ -148,7 +148,7 @@ std::vector> Rsqrt::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RsqrtForward"}, input)}; } @@ -164,7 +164,7 @@ std::vector> Rsqrt::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RsqrtBackward"}, grad_output, input)}; } @@ -172,7 +172,7 @@ std::vector> Exp::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ExpForward"}, input)}; } @@ -180,7 +180,7 @@ std::vector> Exp::Backward(const 
std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "ExpBackward"}, grad_output)}; } @@ -188,7 +188,7 @@ std::vector> Log::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LogForward"}, input)}; } @@ -204,7 +204,7 @@ std::vector> Log::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LogBackward"}, grad_output, input)}; } @@ -213,7 +213,7 @@ std::vector> Equals::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "EqualsForward"}, input, other)}; } @@ -226,7 +226,7 @@ std::vector> EqualsScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "EqualsScalarForward"}, input, scalar_)}; } @@ -240,7 +240,7 @@ std::vector> Lt::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LtForward"}, a, b)}; } @@ -253,7 +253,7 @@ std::vector> LtScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LtScalarForward"}, input, scalar_)}; } @@ -267,7 +267,7 @@ std::vector> Le::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LeForward"}, a, b)}; } @@ -280,7 +280,7 @@ std::vector> LeScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LeScalarForward"}, input, scalar_)}; } @@ -294,7 +294,7 @@ std::vector> Gt::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GtForward"}, a, b)}; } @@ -307,7 +307,7 @@ std::vector> GtScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GtScalarForward"}, input, scalar_)}; } @@ -321,7 +321,7 @@ std::vector> Ge::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GeForward"}, a, b)}; } @@ -334,7 +334,7 @@ std::vector> GeScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "GeScalarForward"}, input, scalar_)}; } @@ -348,7 +348,7 @@ std::vector> Or::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "OrForward"}, a, b)}; } @@ -362,7 +362,7 @@ std::vector> And::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AndForward"}, a, b)}; } @@ -376,7 +376,7 @@ std::vector> Add::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddForward"}, a, b)}; } @@ -390,7 +390,7 @@ std::vector> Add::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "AddBackward"}, grad_output, a_dims_, b_dims_); return {grad_a, grad_b}; @@ -400,7 +400,7 @@ std::vector> 
AddScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddScalarForward"}, input, scalar_)}; } @@ -408,7 +408,7 @@ std::vector> AddScalar::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "AddScalarBackward"}, grad_output)}; } @@ -417,7 +417,7 @@ std::vector> Sub::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SubForward"}, a, b)}; } @@ -431,7 +431,7 @@ std::vector> Sub::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "SubBackward"}, grad_output, a_dims_, b_dims_); return {grad_a, grad_b}; @@ -442,7 +442,7 @@ std::vector> Mul::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulForward"}, a, b)}; } @@ -460,7 +460,7 @@ std::vector> Mul::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "MulBackward"}, grad_output, a, b); return {grad_a, grad_b}; @@ -470,7 +470,7 @@ std::vector> MulScalar::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulScalarForward"}, input, scalar_)}; } @@ -478,7 +478,7 @@ std::vector> MulScalar::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MulScalarBackward"}, grad_output, scalar_)}; } @@ -487,7 +487,7 @@ std::vector> Div::Forward(const std::vectorGetDevice()->Type(); + auto device = a->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "DivForward"}, a, b)}; } @@ -505,7 +505,7 @@ std::vector> Div::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto [grad_a, grad_b] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "DivBackward"}, grad_output, a, b); return {grad_a, grad_b}; diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 48ad02a9..097b8443 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -4,6 +4,7 @@ #include "infini_train/include/autograd/accumulate.h" #include "infini_train/include/autograd/grad_mode.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -12,9 +13,8 @@ namespace infini_train::autograd { std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); - const auto *device = input_tensors[0]->GetDevice(); - // TODO(dcj): Cache context information to reduce setDevice overhead. - device->SetDevice(); + auto device = input_tensors[0]->GetDevice(); + core::DeviceGuard guard(device); std::vector> output_tensors; { @@ -60,8 +60,8 @@ std::vector> Function::Apply(const std::vector &grad_output, int grad_output_idx) { - const auto *device = grad_output->GetDevice(); - device->SetDevice(); + auto device = grad_output->GetDevice(); + core::DeviceGuard guard(device); // NOTE(dcj): The accumulate autograd function has no grad_outputs. 
// Temporarily resize the vector to hold one nullptr as a buffer. @@ -72,7 +72,7 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g grad_outputs_[grad_output_idx] = grad_output; ++grad_outputs_reached_; } else { - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(grad_output, 1.0f, grad_outputs_.at(grad_output_idx)); } ++dependencies_reached_; diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index 53330211..be397c32 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -12,7 +12,7 @@ std::vector> Linear::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "LinearForward"}, input, weight, true, bias)}; } @@ -32,7 +32,7 @@ std::vector> Linear::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [grad_input, grad_weight, grad_bias] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( diff --git a/infini_train/src/autograd/loss.cc b/infini_train/src/autograd/loss.cc index 26e9957d..657ea649 100644 --- a/infini_train/src/autograd/loss.cc +++ b/infini_train/src/autograd/loss.cc @@ -11,7 +11,7 @@ std::vector> CrossEntropy::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "CrossEntropyForward"}, input, target)}; } @@ -29,7 +29,7 @@ std::vector> CrossEntropy::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>({device, "CrossEntropyBackward"}, input, target, grad_output); return {grad_input, nullptr}; diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 68136ba0..335396d6 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -11,7 +11,7 @@ std::vector> Matmul::Forward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MatmulForward"}, input1, input2)}; } @@ -31,7 +31,7 @@ std::vector> Matmul::Backward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); auto [grad_input1, grad_input2] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "MatmulBackward"}, input1, input2, grad_output); diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc index cdfba331..601258eb 100644 --- a/infini_train/src/autograd/misc.cc +++ b/infini_train/src/autograd/misc.cc @@ -10,7 +10,7 @@ std::vector> Split::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, split_size_, dim_)}; } @@ -23,7 +23,7 @@ void Split::SetupContext(const std::vector> &input_tenso std::vector> Split::Backward(const std::vector> &grad_outputs) { auto device = grad_outputs[0]->GetDevice(); - return {Dispatcher::Instance().Call>({device->Type(), "SplitBackward"}, input_dims_, + return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, split_size_, dim_, grad_outputs)}; } @@ -32,7 +32,7 @@ std::vector> IndexGather::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, 
"IndexGatherForward"}); return {kernel.Call>(input, index, dim_)}; } @@ -51,7 +51,7 @@ std::vector> IndexGather::Backward(const std::vectorGetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "IndexGatherBackward"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); return {kernel.Call>(grad_output, index, dim_, input_dims_)}; } @@ -59,7 +59,7 @@ std::vector> NoOp::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; } @@ -73,7 +73,7 @@ std::vector> NoOp::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; } @@ -81,7 +81,7 @@ std::vector> Slice::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; } @@ -98,14 +98,14 @@ std::vector> Slice::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, ends_, steps_)}; } std::vector> Stack::Forward(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice()->Type(); + const auto device = input_tensors[0]->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; } @@ -119,14 +119,14 @@ void Stack::SetupContext(const std::vector> &input_tenso std::vector> Stack::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, dim_, grad_output)}; } std::vector> Concat::Forward(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice()->Type(); + const auto device = input_tensors[0]->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); return {kernel.Call>(input_tensors, dim_)}; @@ -140,7 +140,7 @@ void Concat::SetupContext(const std::vector> &input_tens std::vector> Concat::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); return kernel.Call>>(grad_output, input_dims_list_, dim_); } diff --git a/infini_train/src/autograd/normalization.cc b/infini_train/src/autograd/normalization.cc index 58d3bdc5..79a14abb 100644 --- a/infini_train/src/autograd/normalization.cc +++ b/infini_train/src/autograd/normalization.cc @@ -13,7 +13,7 @@ std::vector> LayerNorm::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [output, mean, rstd] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( @@ -40,7 +40,7 @@ std::vector> LayerNorm::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto [grad_input, grad_weight, grad_bias] = Dispatcher::Instance() .Call, std::shared_ptr, std::shared_ptr>>( diff --git a/infini_train/src/autograd/outer.cc b/infini_train/src/autograd/outer.cc index 
347df100..85a8c9ca 100644 --- a/infini_train/src/autograd/outer.cc +++ b/infini_train/src/autograd/outer.cc @@ -14,7 +14,7 @@ std::vector> Outer::Forward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "OuterForward"}, input1, input2)}; } @@ -32,7 +32,7 @@ std::vector> Outer::Backward(const std::vectorGetDevice()->Type(); + auto device = input1->GetDevice().type(); auto [grad_input1, grad_input2] = Dispatcher::Instance().Call, std::shared_ptr>>( {device, "OuterBackward"}, input1, input2, grad_output); diff --git a/infini_train/src/autograd/reduction.cc b/infini_train/src/autograd/reduction.cc index e5244947..5a6e086f 100644 --- a/infini_train/src/autograd/reduction.cc +++ b/infini_train/src/autograd/reduction.cc @@ -13,7 +13,7 @@ std::vector> Mean::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MeanForward"}, input, dim_, keep_dim_)}; } @@ -27,7 +27,7 @@ std::vector> Mean::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MeanBackward"}, grad_output, input_dims_, dim_, keep_dim_)}; } @@ -36,7 +36,7 @@ std::vector> Sum::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SumForward"}, input, dim_, keep_dim_)}; } @@ -50,7 +50,7 @@ std::vector> Sum::Backward(const std::vectorGetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SumBackward"}, grad_output, input_dims_, dim_, keep_dim_)}; } @@ -59,7 +59,7 @@ std::vector> Max::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxForward"}, input, dim_, keep_dim_)}; } @@ -77,7 +77,7 @@ std::vector> Max::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxBackward"}, grad_output, input, reduced, dim_, keep_dim_)}; } @@ -86,7 +86,7 @@ std::vector> Min::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinForward"}, input, dim_, keep_dim_)}; } @@ -104,7 +104,7 @@ std::vector> Min::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinBackward"}, grad_output, input, reduced, dim_, keep_dim_)}; } diff --git a/infini_train/src/autograd/softmax.cc b/infini_train/src/autograd/softmax.cc index 1987b6f7..39569a8c 100644 --- a/infini_train/src/autograd/softmax.cc +++ b/infini_train/src/autograd/softmax.cc @@ -10,7 +10,7 @@ std::vector> Softmax::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "SoftmaxForward"}, input, dim_)}; } @@ -26,7 +26,7 @@ std::vector> Softmax::Backward(const std::vectorGetDevice()->Type(); + auto device = output->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "SoftmaxBackward"}, grad_output, output, dim_)}; } diff --git a/infini_train/src/autograd/sparse.cc b/infini_train/src/autograd/sparse.cc index 19867d55..93315b4f 100644 --- a/infini_train/src/autograd/sparse.cc +++ b/infini_train/src/autograd/sparse.cc @@ -11,7 +11,7 @@ std::vector> Embedding::Forward(const 
std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "EmbeddingForward"}, input, weight)}; } @@ -28,7 +28,7 @@ std::vector> Embedding::Backward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); auto grad_weight = Dispatcher::Instance().Call>({device, "EmbeddingBackward"}, input, weight_dims_, grad_output); return {nullptr, grad_weight}; diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index 3c33fea3..4fae05bb 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -8,14 +8,14 @@ std::vector> Tril::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TrilForward"}, input, diagonal_)}; } std::vector> Tril::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TrilBackward"}, grad_output, diagonal_)}; } @@ -23,14 +23,14 @@ std::vector> Triu::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TriuForward"}, input, diagonal_)}; } std::vector> Triu::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TriuBackward"}, grad_output, diagonal_)}; } @@ -38,14 +38,14 @@ std::vector> Transpose::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "TransposeForward"}, input, dim0_, dim1_)}; } std::vector> Transpose::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return { Dispatcher::Instance().Call>({device, "TransposeBackward"}, grad_output, dim0_, dim1_)}; } @@ -54,14 +54,14 @@ std::vector> Mask::Forward(const std::vectorGetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaskForward"}, input, mask_, value_)}; } std::vector> Mask::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaskBackward"}, grad_output, mask_)}; } @@ -70,7 +70,7 @@ RepeatInterleave::Forward(const std::vector> &input_tens CHECK_EQ(input_tensors.size(), 1); const auto &input = input_tensors[0]; - auto device = input->GetDevice()->Type(); + auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RepeatInterleaveForward"}, input, repeat_, dim_)}; } @@ -85,7 +85,7 @@ std::vector> RepeatInterleave::Backward(const std::vector> &grad_outputs) { const auto &grad_output = grad_outputs[0]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "RepeatInterleaveBackward"}, grad_output, input_dims_, dim_)}; } diff --git a/infini_train/src/core/cpu/cpu_guard.cc b/infini_train/src/core/cpu/cpu_guard.cc new file mode 100644 index 00000000..6d98d30f --- /dev/null +++ 
b/infini_train/src/core/cpu/cpu_guard.cc
@@ -0,0 +1,18 @@
+#include "infini_train/src/core/cpu/cpu_guard.h"
+
+#include <cstdlib>
+#include <cstring>
+
+namespace infini_train::core::cpu {
+
+Device CpuGuardImpl::GetDevice() const { return Device(Device::DeviceType::kCPU, 0); }
+
+Device::DeviceType CpuGuardImpl::Type() const { return Device::DeviceType::kCPU; }
+
+void CpuGuardImpl::Malloc(void **dev_ptr, size_t size) { *dev_ptr = std::malloc(size); }
+
+void CpuGuardImpl::Free(void *dev_ptr) { std::free(dev_ptr); }
+
+void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { std::memcpy(dst, src, count); }
+
+} // namespace infini_train::core::cpu
diff --git a/infini_train/src/core/cpu/cpu_guard.h b/infini_train/src/core/cpu/cpu_guard.h
new file mode 100644
index 00000000..3b6ac71f
--- /dev/null
+++ b/infini_train/src/core/cpu/cpu_guard.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "infini_train/include/core/device_guard.h"
+
+namespace infini_train::core::cpu {
+
+class CpuGuardImpl : public DeviceGuardImpl {
+public:
+    CpuGuardImpl() = default;
+
+    Device GetDevice() const;
+
+    Device::DeviceType Type() const;
+
+    void Malloc(void **dev_ptr, size_t size);
+
+    void Free(void *dev_ptr);
+
+    void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind);
+};
+
+} // namespace infini_train::core::cpu
diff --git a/infini_train/src/core/cuda/cuda_blas_handle.cc b/infini_train/src/core/cuda/cuda_blas_handle.cc
new file mode 100644
index 00000000..36da1eab
--- /dev/null
+++ b/infini_train/src/core/cuda/cuda_blas_handle.cc
@@ -0,0 +1,17 @@
+
+#include "infini_train/src/core/cuda/cuda_blas_handle.h"
+
+#include "infini_train/include/common/cuda/common_cuda.h"
+
+#include "infini_train/src/core/cuda/cuda_stream.h"
+
+namespace infini_train::core::cuda {
+
+CudaBlasHandle::CudaBlasHandle(Stream *stream) {
+    CUBLAS_CHECK(cublasCreate(&cublas_handle_));
+    CUBLAS_CHECK(cublasSetStream(cublas_handle_, dynamic_cast<CudaStream *>(stream)->cuda_stream()));
+}
+
+cublasHandle_t CudaBlasHandle::cublas_handle() const { return cublas_handle_; }
+
+} // namespace infini_train::core::cuda
diff --git a/infini_train/src/core/cuda/cuda_blas_handle.h b/infini_train/src/core/cuda/cuda_blas_handle.h
new file mode 100644
index 00000000..53678916
--- /dev/null
+++ b/infini_train/src/core/cuda/cuda_blas_handle.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <cublas_v2.h>
+
+#include "infini_train/include/core/blas_handle.h"
+
+namespace infini_train::core {
+class Stream;
+}
+
+namespace infini_train::core::cuda {
+
+class CudaBlasHandle : public BlasHandle {
+public:
+    explicit CudaBlasHandle(Stream *stream);
+
+    cublasHandle_t cublas_handle() const;
+
+private:
+    cublasHandle_t cublas_handle_;
+};
+
+} // namespace infini_train::core::cuda
diff --git a/infini_train/src/core/cuda/cuda_guard.cc b/infini_train/src/core/cuda/cuda_guard.cc
new file mode 100644
index 00000000..ae0b34ef
--- /dev/null
+++ b/infini_train/src/core/cuda/cuda_guard.cc
@@ -0,0 +1,131 @@
+#include "infini_train/src/core/cuda/cuda_guard.h"
+
+#include <array>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+
+#include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/core/blas_handle.h"
+#include "infini_train/include/device.h"
+
+#include "infini_train/src/core/cuda/cuda_blas_handle.h"
+#include "infini_train/src/core/cuda/cuda_stream.h"
+
+namespace infini_train::core::cuda {
+namespace {
+constexpr int kMaxGpus = 8;
+
+static std::array<std::unique_ptr<CudaStream>, kMaxGpus> cuda_streams;
+static std::array<std::unique_ptr<CudaBlasHandle>, kMaxGpus> cuda_blas_handles;
+
+static std::array<std::once_flag, kMaxGpus> device_stream_flags;
+static std::array<std::once_flag, kMaxGpus> device_handle_flags;
+} // namespace
+
+void CudaGuardImpl::InitSingleStream(Device device) {
+    int current_device = -1;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+    CUDA_CHECK(cudaSetDevice(device.index()));
+
+    cuda_streams[device.index()] = std::make_unique<CudaStream>();
+
+    CUDA_CHECK(cudaSetDevice(current_device));
+}
+
+void CudaGuardImpl::InitSingleHandle(Device device) {
+    int current_device = -1;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+    CUDA_CHECK(cudaSetDevice(device.index()));
+
+    std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device);
+
+    cuda_blas_handles[device.index()] = std::make_unique<CudaBlasHandle>(cuda_streams[device.index()].get());
+
+    CUDA_CHECK(cudaSetDevice(current_device));
+}
+
+CudaGuardImpl::CudaGuardImpl() {}
+
+// device
+Device CudaGuardImpl::GetDevice() const {
+    int current_device = -1;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+    return Device(Device::DeviceType::kCUDA, current_device);
+}
+
+void CudaGuardImpl::SetDevice(Device device) const { CUDA_CHECK(cudaSetDevice(device.index())); }
+
+int8_t CudaGuardImpl::DeviceCount() const {
+    int device_count = 0;
+    CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count));
+    return device_count;
+}
+
+Device::DeviceType CudaGuardImpl::Type() const { return Device::DeviceType::kCUDA; }
+
+// stream
+Stream *CudaGuardImpl::GetStream(Device device) const {
+    std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device);
+    return cuda_streams.at(device.index()).get();
+}
+
+// event
+
+// sync
+void CudaGuardImpl::SynchronizeDevice(Device device) const {
+    auto original_device = GetDevice();
+    SetDevice(device);
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    SetDevice(original_device);
+}
+
+// blas
+BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const {
+    std::call_once(device_handle_flags.at(device.index()), InitSingleHandle, device);
+    return cuda_blas_handles.at(device.index()).get();
+}
+
+// memory
+void CudaGuardImpl::Malloc(void **dev_ptr, size_t size) { CUDA_CHECK(cudaMalloc(dev_ptr, size)); }
+
+void CudaGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) {
+    CUDA_CHECK(cudaMallocAsync(dev_ptr, size, dynamic_cast<CudaStream *>(stream)->cuda_stream()));
+}
+
+void CudaGuardImpl::Free(void *dev_ptr) { CUDA_CHECK(cudaFree(dev_ptr)); }
+
+void CudaGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) {
+    CUDA_CHECK(cudaFreeAsync(dev_ptr, dynamic_cast<CudaStream *>(stream)->cuda_stream()));
+}
+
+void CudaGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) {
+    if (kind == MemcpyKind::kH2D) {
+        CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice));
+    } else if (kind == MemcpyKind::kD2H) {
+        CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost));
+    } else if (kind == MemcpyKind::kD2D) {
+        CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToDevice));
+    } else {
+        LOG(FATAL) << "Invalid MemcpyKind";
+    }
+}
+
+void CudaGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) {
+    cudaStream_t cuda_stream = dynamic_cast<CudaStream *>(stream)->cuda_stream();
+    if (kind == MemcpyKind::kH2D) {
+        CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyHostToDevice, cuda_stream));
+    } else if (kind == MemcpyKind::kD2H) {
+        CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToHost, cuda_stream));
+    } else if (kind == MemcpyKind::kD2D) {
+        CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, cuda_stream));
+    } else {
+        LOG(FATAL) << "Invalid MemcpyKind";
+    }
+}
+
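+// NOTE: the registration macro below is assumed (its definition lives in
+// infini_train/include/core/device_guard.h, outside this diff) to expand to a static
+// registrar that calls DeviceGuardImplRegistry::Instance().Register(Device::DeviceType::kCUDA,
+// std::make_unique<CudaGuardImpl>()) during static initialization, so the CUDA backend is
+// picked up automatically whenever this translation unit is linked in.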
+INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_guard.h b/infini_train/src/core/cuda/cuda_guard.h new file mode 100644 index 00000000..e8360025 --- /dev/null +++ b/infini_train/src/core/cuda/cuda_guard.h @@ -0,0 +1,54 @@ +#pragma once + +#include + +#include "infini_train/include/core/blas_handle.h" +#include "infini_train/include/core/device_guard.h" +#include "infini_train/include/core/stream.h" +#include "infini_train/include/device.h" + +namespace infini_train::core::cuda { + +class CudaGuardImpl : public DeviceGuardImpl { +public: + static void InitSingleStream(Device device); + + static void InitSingleHandle(Device device); + + CudaGuardImpl(); + + // device + Device GetDevice() const override; + + void SetDevice(Device device) const override; + + int8_t DeviceCount() const override; + + Device::DeviceType Type() const override; + + // stream + Stream *GetStream(Device device) const override; + + // event + + // sync + void SynchronizeDevice(Device device) const override; + + // blas + BlasHandle *GetBlasHandle(Device device) const override; + + // memory + void Malloc(void **dev_ptr, size_t size) override; + + void MallocAsync(void **dev_ptr, size_t size, Stream *stream) override; + + void Free(void *dev_ptr) override; + + void FreeAsync(void *dev_ptr, Stream *stream) override; + + void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) override; + + void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) override; +}; + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_stream.cc b/infini_train/src/core/cuda/cuda_stream.cc new file mode 100644 index 00000000..82d04566 --- /dev/null +++ b/infini_train/src/core/cuda/cuda_stream.cc @@ -0,0 +1,12 @@ +#include "infini_train/src/core/cuda/cuda_stream.h" + +#include + +#include "infini_train/include/common/cuda/common_cuda.h" + +namespace infini_train::core::cuda { +CudaStream::CudaStream() { CUDA_CHECK(cudaStreamCreate(&stream_)); } + +cudaStream_t CudaStream::cuda_stream() const { return stream_; } + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/cuda/cuda_stream.h b/infini_train/src/core/cuda/cuda_stream.h new file mode 100644 index 00000000..c5252097 --- /dev/null +++ b/infini_train/src/core/cuda/cuda_stream.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infini_train/include/core/stream.h" + +namespace infini_train::core::cuda { + +class CudaStream : public Stream { +public: + CudaStream(); + + cudaStream_t cuda_stream() const; + +private: + cudaStream_t stream_; +}; + +} // namespace infini_train::core::cuda diff --git a/infini_train/src/core/device_guard.cc b/infini_train/src/core/device_guard.cc new file mode 100644 index 00000000..e6720dca --- /dev/null +++ b/infini_train/src/core/device_guard.cc @@ -0,0 +1,98 @@ +#include "infini_train/include/core/device_guard.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/blas_handle.h" +#include "infini_train/include/core/stream.h" +#include "infini_train/src/core/cpu/cpu_guard.h" + +namespace infini_train::core { + +// DeviceGuardImpl +int8_t DeviceGuardImpl::DeviceCount() const { return -1; } + +Stream *DeviceGuardImpl::GetStream(Device) const { return nullptr; } + +BlasHandle *DeviceGuardImpl::GetBlasHandle(Device) const { + LOG(FATAL) << "GetBlasHandle function not implemented for this 
device"; +} + +void DeviceGuardImpl::MallocAsync(void **dev_ptr, size_t size, Stream *stream) { + LOG(WARNING) << "MallocAsync is not supported on this device. Falling back to blocking Malloc()"; + Malloc(dev_ptr, size); +} + +void DeviceGuardImpl::FreeAsync(void *dev_ptr, Stream *stream) { + LOG(WARNING) << "FreeAsync is not supported on this device. Falling back to blocking Free()"; + Free(dev_ptr); +} + +void DeviceGuardImpl::MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream) { + LOG(WARNING) << "MemcpyAsync is not supported on this device. Falling back to blocking Memcpy()"; + Memcpy(dst, src, count, kind); +} + +// DeviceGuard +DeviceGuard::DeviceGuard(Device device) : impl_(GetDeviceGuardImpl(device.type())) { + original_device_ = impl_->GetDevice(); + impl_->SetDevice(device); +} + +void DeviceGuard::SetDevice(Device device) { + if (current_device_ == device) { + return; + } + impl_->SetDevice(device); + current_device_ = device; +} + +Device DeviceGuard::current_device() const { return current_device_; } + +Device DeviceGuard::original_device() const { return original_device_; } + +DeviceGuard::~DeviceGuard() { impl_->SetDevice(original_device_); } + +// DeviceGuardImplRegistry +DeviceGuardImplRegistry &DeviceGuardImplRegistry::Instance() { + static DeviceGuardImplRegistry instance; + instance.Register(Device::DeviceType::kCPU, std::make_unique()); + return instance; +} + +void DeviceGuardImplRegistry::Register(Device::DeviceType type, std::unique_ptr impl) { + if (type != impl->Type()) { + LOG(FATAL) << std::format("Register device guard impl with type {}, but as type {}", + static_cast(impl->Type()), static_cast(type)); + } + + if (impls_.contains(type)) { + LOG(FATAL) << std::format("DeviceGuardImpl for type {} already registrered", static_cast(type)); + } + + if (!impls_.empty()) { + for (auto &kv : impls_) { + if (kv.first != Device::DeviceType::kCPU) { + LOG(FATAL) << std::format("Only CPU and one GPU backend allowed. 
Already have GPU={}, new={} rejected.", + static_cast(kv.first), static_cast(type)); + } + } + } + + impls_[type] = std::move(impl); +} + +DeviceGuardImpl *DeviceGuardImplRegistry::Get(Device::DeviceType type) const { + auto it = impls_.find(type); + if (it == impls_.end()) { + LOG(FATAL) << "No DeviceGuardImpl registered for type " << static_cast(type); + } + return it->second.get(); +} + +DeviceGuardImpl *GetDeviceGuardImpl(Device::DeviceType type) { return DeviceGuardImplRegistry::Instance().Get(type); } + +} // namespace infini_train::core diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index 4271ff97..1bb3aaad 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -1,38 +1,40 @@ #include "infini_train/include/device.h" #include -#include +#include +#include +#include #include "glog/logging.h" #include "infini_train/include/nn/parallel/global.h" -#ifdef USE_CUDA -#include "infini_train/include/common/cuda/common_cuda.h" -#endif namespace infini_train { +Device::Device() : type_(DeviceType::kCPU), index_(0) {} + Device::Device(DeviceType type, int8_t index) : type_(type), index_(index) { if (type_ == DeviceType::kCPU && index_ != 0) { LOG(FATAL) << "CPU device index should be 0"; } } -DeviceType Device::Type() const { return type_; } -int8_t Device::Index() const { return index_; } +Device::DeviceType Device::type() const { return type_; } + +int8_t Device::index() const { return index_; } bool Device::IsCPU() const { return type_ == DeviceType::kCPU; } + bool Device::IsCUDA() const { return type_ == DeviceType::kCUDA; } std::string Device::ToString() const { std::ostringstream oss; - oss << "Device(" << (type_ == DeviceType::kCPU ? "CPU" : "CUDA") << ", " << static_cast(index_) << ")"; + oss << std::format("Device({}, {})", type_ == DeviceType::kCPU ? 
"CPU" : "CUDA", index_); return oss.str(); } -nn::parallel::Rank Device::rank() const { - LOG(FATAL) << "Unimplemented"; - // prevent the compiler warning about control reaching the end of non-void function - std::abort(); +nn::parallel::Rank Device::Rank() const { + return {nn::parallel::global::GetGlobalProcRank(), index_, nn::parallel::global::GetNprocPerNode(), + nn::parallel::global::GetNthreadPerProc()}; } std::ostream &operator<<(std::ostream &os, const Device &device) { @@ -40,71 +42,8 @@ std::ostream &operator<<(std::ostream &os, const Device &device) { return os; } -CpuDevice::CpuDevice() : Device(DeviceType::kCPU, 0) {} - -#ifdef USE_CUDA -CudaDevice::~CudaDevice() { - if (stream_ != nullptr) { - CUDA_CHECK(cudaStreamDestroy(stream_)); - } - - if (cublas_handle_ != nullptr) { - CUBLAS_CHECK(cublasDestroy(cublas_handle_)); - } -} - -void CudaDevice::SetDevice() const { CUDA_CHECK(cudaSetDevice(index_)); } -void CudaDevice::Synchronize() const { CUDA_CHECK(cudaDeviceSynchronize()); } - -cudaStream_t CudaDevice::Stream() const { return stream_; } - -cublasHandle_t CudaDevice::CublasHandle() const { return cublas_handle_; } - -nn::parallel::Rank CudaDevice::rank() const { return rank_; } - -CudaDevice::CudaDevice(int8_t index) - : Device(DeviceType::kCUDA, index), - rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(), - nn::parallel::global::GetNthreadPerProc()}) { - // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode - SetDevice(); - CUDA_CHECK(cudaStreamCreate(&stream_)); - - CUBLAS_CHECK(cublasCreate(&cublas_handle_)); - CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_)); -} -#endif // USE_CUDA - -const DeviceManager *DeviceManager::Instance() { - static auto instance = std::unique_ptr(new DeviceManager()); - return instance.get(); -} - -const Device *DeviceManager::GetDevice(DeviceType type, int8_t index) const { - return devices_map_.at(type).at(index).get(); -} - -const Device *DeviceManager::GetDefaultDevice() const { return devices_map_.at(DeviceType::kCPU).at(0).get(); } - -std::vector DeviceManager::GetAllAvailableDevices(DeviceType device_type) const { - std::vector devices; - for (const auto &device : devices_map_.at(device_type)) { devices.push_back(device.get()); } - return devices; -} +bool operator==(const Device &a, const Device &b) { return a.type_ == b.type_ && a.index_ == b.index_; } -DeviceManager::DeviceManager() { - devices_map_[DeviceType::kCPU].push_back(std::unique_ptr(new CpuDevice())); -#ifdef USE_CUDA - CUDA_DRIVER_CHECK(cuInit(0)); - int device_count = 0; - CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count)); - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - for (int idx = 0; idx < device_count; ++idx) { - devices_map_[DeviceType::kCUDA].push_back(std::unique_ptr(new CudaDevice(idx))); - } - CUDA_CHECK(cudaSetDevice(current_device)); -#endif -} +bool operator!=(const Device &a, const Device &b) { return !(a == b); } } // namespace infini_train diff --git a/infini_train/src/kernels/cpu/accumulate_grad.cc b/infini_train/src/kernels/cpu/accumulate_grad.cc index 171d722c..cfe85b9c 100644 --- a/infini_train/src/kernels/cpu/accumulate_grad.cc +++ b/infini_train/src/kernels/cpu/accumulate_grad.cc @@ -37,7 +37,7 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p } // namespace infini_train::kernels::cpu #define REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, 
kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(AccumulateGrad) REGISTER_CPU_ACCUMULATE_GRAD_KERNEL(AdamAccumulateGrad) diff --git a/infini_train/src/kernels/cpu/cast.cc b/infini_train/src/kernels/cpu/cast.cc index 8481eb15..35f31214 100644 --- a/infini_train/src/kernels/cpu/cast.cc +++ b/infini_train/src/kernels/cpu/cast.cc @@ -24,7 +24,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { } // namespace infini_train::kernels::cpu #define REGISTER_CPU_CAST_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_CAST_KERNEL(Cast) diff --git a/infini_train/src/kernels/cpu/concat.cc b/infini_train/src/kernels/cpu/concat.cc index d294eb85..b421063f 100644 --- a/infini_train/src/kernels/cpu/concat.cc +++ b/infini_train/src/kernels/cpu/concat.cc @@ -128,7 +128,7 @@ std::vector> ConcatBackward(const std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu } // namespace infini_train::kernels::cpu #define REGISTER_CPU_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_CROSS_ENTROPY_KERNEL(CrossEntropyForward) REGISTER_CPU_CROSS_ENTROPY_KERNEL(CrossEntropyBackward) diff --git a/infini_train/src/kernels/cpu/elementwise.cc b/infini_train/src/kernels/cpu/elementwise.cc index 608172b6..8d66acd2 100644 --- a/infini_train/src/kernels/cpu/elementwise.cc +++ b/infini_train/src/kernels/cpu/elementwise.cc @@ -313,7 +313,7 @@ std::pair, std::shared_ptr> DivBackward(const st } // namespace infini_train::kernels::cpu #define REGISTER_CPU_ELEMENTWISE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_ELEMENTWISE_KERNEL(NegForward) REGISTER_CPU_ELEMENTWISE_KERNEL(NegBackward) diff --git a/infini_train/src/kernels/cpu/embedding.cc b/infini_train/src/kernels/cpu/embedding.cc index 5debac9f..190c77c5 100644 --- a/infini_train/src/kernels/cpu/embedding.cc +++ b/infini_train/src/kernels/cpu/embedding.cc @@ -56,7 +56,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_EMBEDDING_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_EMBEDDING_KERNEL(EmbeddingForward) REGISTER_CPU_EMBEDDING_KERNEL(EmbeddingBackward) diff --git a/infini_train/src/kernels/cpu/fill.cc b/infini_train/src/kernels/cpu/fill.cc index 2e8fdbc7..175a15a2 100644 --- a/infini_train/src/kernels/cpu/fill.cc +++ b/infini_train/src/kernels/cpu/fill.cc @@ -12,7 +12,7 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { } // namespace infini_train::kernels::cpu #define REGISTER_CPU_FILL_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + 
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_FILL_KERNEL(Fill) diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index c612efaa..9717b795 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -197,7 +197,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ } // namespace infini_train::kernels::cpu #define REGISTER_CPU_GATHER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_GATHER_KERNEL(IndexGatherForward) REGISTER_CPU_GATHER_KERNEL(IndexGatherBackward) diff --git a/infini_train/src/kernels/cpu/layernorm.cc b/infini_train/src/kernels/cpu/layernorm.cc index d717f348..c587f2c5 100644 --- a/infini_train/src/kernels/cpu/layernorm.cc +++ b/infini_train/src/kernels/cpu/layernorm.cc @@ -139,7 +139,7 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr &input, const std::shared_ptr NoOpBackward(const std::vector &dims, const std } // namespace infini_train::kernels::cpu #define REGISTER_CPU_NO_OP_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_NO_OP_KERNEL(NoOpForward) REGISTER_CPU_NO_OP_KERNEL(NoOpBackward) diff --git a/infini_train/src/kernels/cpu/outer.cc b/infini_train/src/kernels/cpu/outer.cc index 2991dfd3..b61a3ed0 100644 --- a/infini_train/src/kernels/cpu/outer.cc +++ b/infini_train/src/kernels/cpu/outer.cc @@ -59,7 +59,7 @@ std::tuple, std::shared_ptr> OuterBackward(const } // namespace infini_train::kernels::cpu #define REGISTER_CPU_OUTER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_OUTER_KERNEL(OuterForward) REGISTER_CPU_OUTER_KERNEL(OuterBackward) diff --git a/infini_train/src/kernels/cpu/reduction.cc b/infini_train/src/kernels/cpu/reduction.cc index 87ed5384..0aa936ba 100644 --- a/infini_train/src/kernels/cpu/reduction.cc +++ b/infini_train/src/kernels/cpu/reduction.cc @@ -169,7 +169,7 @@ std::shared_ptr MinBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_REDUCTION_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_REDUCTION_KERNEL(MeanForward) REGISTER_CPU_REDUCTION_KERNEL(MeanBackward) diff --git a/infini_train/src/kernels/cpu/sigmoid.cc b/infini_train/src/kernels/cpu/sigmoid.cc index d4bc05da..8163a096 100644 --- a/infini_train/src/kernels/cpu/sigmoid.cc +++ b/infini_train/src/kernels/cpu/sigmoid.cc @@ -35,7 +35,7 @@ std::shared_ptr SigmoidBackward(const std::shared_ptr &output, } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SIGMOID_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, 
kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SIGMOID_KERNEL(SigmoidForward) REGISTER_CPU_SIGMOID_KERNEL(SigmoidBackward) diff --git a/infini_train/src/kernels/cpu/slice.cc b/infini_train/src/kernels/cpu/slice.cc index 943b1c1b..bef925a7 100644 --- a/infini_train/src/kernels/cpu/slice.cc +++ b/infini_train/src/kernels/cpu/slice.cc @@ -130,7 +130,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SLICE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SLICE_KERNEL(SliceForward) REGISTER_CPU_SLICE_KERNEL(SliceBackward) diff --git a/infini_train/src/kernels/cpu/softmax.cc b/infini_train/src/kernels/cpu/softmax.cc index 454bdc2d..f711fbdc 100644 --- a/infini_train/src/kernels/cpu/softmax.cc +++ b/infini_train/src/kernels/cpu/softmax.cc @@ -81,7 +81,7 @@ std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_outp } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SOFTMAX_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SOFTMAX_KERNEL(SoftmaxForward) REGISTER_CPU_SOFTMAX_KERNEL(SoftmaxBackward) diff --git a/infini_train/src/kernels/cpu/split.cc b/infini_train/src/kernels/cpu/split.cc index e9a90ea9..209857f0 100644 --- a/infini_train/src/kernels/cpu/split.cc +++ b/infini_train/src/kernels/cpu/split.cc @@ -74,7 +74,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in } // namespace infini_train::kernels::cpu #define REGISTER_CPU_SPLIT_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SPLIT_KERNEL(SplitForward) REGISTER_CPU_SPLIT_KERNEL(SplitBackward) diff --git a/infini_train/src/kernels/cpu/stack.cc b/infini_train/src/kernels/cpu/stack.cc index d1f71ed2..0ada6475 100644 --- a/infini_train/src/kernels/cpu/stack.cc +++ b/infini_train/src/kernels/cpu/stack.cc @@ -81,7 +81,7 @@ std::vector> StackBackward(const std::vector &i } // namespace infini_train::kernels::cpu #define REGISTER_CPU_STACK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_STACK_KERNEL(StackForward) REGISTER_CPU_STACK_KERNEL(StackBackward) diff --git a/infini_train/src/kernels/cpu/transform.cc b/infini_train/src/kernels/cpu/transform.cc index 1c1697b0..00387917 100644 --- a/infini_train/src/kernels/cpu/transform.cc +++ b/infini_train/src/kernels/cpu/transform.cc @@ -219,7 +219,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & } // namespace infini_train::kernels::cpu #define REGISTER_CPU_TRANSFORM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_TRANSFORM_KERNEL(TrilForward) 
REGISTER_CPU_TRANSFORM_KERNEL(TrilBackward) diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index 2b1d486c..54fa0ad2 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -2,8 +2,10 @@ #include #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -21,12 +23,15 @@ void AccumulateGrad(const std::shared_ptr &gradient, float rate, const s int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(tensor->GetDevice()); + auto device = tensor->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( gradient->Dtype(), [=]() { - AccumulateGradKernel<<Stream()>>>( + AccumulateGradKernel<<>>( static_cast(gradient->DataPtr()), rate, static_cast(tensor->DataPtr()), num_elements); }, "CUDA AccumulateGrad"); @@ -61,12 +66,15 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad->GetDevice()); + + auto device = grad->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device))->cuda_stream(); DispatchFunc( grad->Dtype(), [=]() { - AdamAccumulateGradKernel<<Stream()>>>( + AdamAccumulateGradKernel<<>>( static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, eps, bias_correction_m, bias_correction_v); @@ -76,7 +84,7 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(AccumulateGrad) REGISTER_CUDA_ACCUMULATE_GRAD_KERNEL(AdamAccumulateGrad) diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 6b53e8c8..e81d2a35 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -6,6 +6,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -20,7 +21,10 @@ __global__ void CastKernel(Tdst *dst, const Tsrc *src, size_t num_elements, size std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto dst_tensor = std::make_shared(input->Dims(), dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); const size_t num_elements = input->NumElements(); dim3 block_dims(256); @@ -33,7 +37,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { 
auto dst = static_cast(dst_tensor->DataPtr()); auto src = static_cast(input->DataPtr()); for (size_t offset = 0; offset < num_elements; offset += step) { - CastKernel<<Stream()>>>(dst, src, num_elements, offset); + CastKernel<<>>(dst, src, num_elements, offset); } }, "CUDA Cast"); @@ -43,7 +47,7 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_CAST_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_CAST_KERNEL(Cast) diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index c84cc068..c3063e99 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -12,7 +12,7 @@ namespace infini_train::kernels::cuda { std::vector> Broadcast(const std::vector> &input_tensors, - const std::vector &devices) { + const std::vector &devices) { std::vector> outputs; for (int i = 0; i < devices.size(); ++i) { for (const auto &tensor : input_tensors) { @@ -23,9 +23,9 @@ std::vector> Broadcast(const std::vector> ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) { + Device destination) { std::vector> outputs; - auto kernel = Dispatcher::Instance().GetKernel({destination->Type(), "AccumulateGrad"}); + auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"}); std::vector>> to_destination_grads; for (int i = 0; i < grads[0].size(); ++i) { outputs.emplace_back(std::make_shared(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination)); @@ -45,7 +45,7 @@ std::vector> ReduceAddCoalesced(const std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, +std::vector> Scatter(const std::shared_ptr &tensor, std::vector devices, int64_t dim) { std::vector> outputs; // FIXME(dcj): do split without autograd @@ -56,22 +56,22 @@ std::vector> Scatter(const std::shared_ptr &tens return outputs; } -std::shared_ptr Gather(const std::vector> &tensors, const Device *destination, - int64_t dim) { +std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) { std::vector> outputs; for (const auto &tensor : tensors) { outputs.push_back(std::make_shared(tensor->To(destination))); } - auto kernel = Dispatcher::Instance().GetKernel({tensors[0]->GetDevice()->Type(), "StackForward"}); + auto kernel = Dispatcher::Instance().GetKernel({tensors[0]->GetDevice().type(), "StackForward"}); auto gathered_tensor = kernel.Call>(outputs, dim); auto old_dims = gathered_tensor->Dims(); std::vector new_dims{old_dims[0] * old_dims[1]}; for (int i = 2; i < old_dims.size(); ++i) { new_dims.push_back(old_dims[i]); } - auto view_kernel = Dispatcher::Instance().GetKernel({destination->Type(), "NoOpForward"}); + auto view_kernel = Dispatcher::Instance().GetKernel({destination.type(), "NoOpForward"}); return view_kernel.Call>(gathered_tensor, new_dims); } } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_COMM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, Comm##kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Comm##kernel_name, \ + infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_COMM_KERNEL(Broadcast) REGISTER_CUDA_COMM_KERNEL(Scatter) diff --git a/infini_train/src/kernels/cuda/concat.cu 
b/infini_train/src/kernels/cuda/concat.cu index f3d1730c..0d60844a 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -9,6 +9,7 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { __device__ __forceinline__ int64_t UpperBoundI64(const int64_t *offsets, int64_t n_plus_1, int64_t x) { @@ -91,8 +92,9 @@ std::shared_ptr ConcatForward(const std::vector> std::vector host_offsets(num_inputs + 1, 0); for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; } - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * K_total * D; int threads_per_block = 256; @@ -175,10 +177,12 @@ std::vector> ConcatBackward(const std::shared_ptrGetDevice(); + std::vector> grads; grads.reserve(input_dims_list.size()); for (const auto &dvec : input_dims_list) { - auto t = std::make_shared(dvec, dtype, grad_output->GetDevice()); + auto t = std::make_shared(dvec, dtype, device); DispatchFunc( dtype, [=]() { t->Fill(0); }, "CUDA ConcatBackward"); grads.push_back(t); @@ -194,8 +198,9 @@ std::vector> ConcatBackward(const std::shared_ptr host_offsets(num_inputs + 1, 0); for (int64_t i = 0; i < num_inputs; ++i) { host_offsets[i + 1] = host_offsets[i] + Ks[i]; } - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * K_total * D; int threads_per_block = 256; @@ -232,7 +237,7 @@ std::vector> ConcatBackward(const std::shared_ptr CrossEntropyForward(const std::shared_ptr &input constexpr int threads_per_block = 256; int num_blocks = bs; - const auto *cuda_device = dynamic_cast(target->GetDevice()); + auto device = target->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + return DispatchFunc, DataTypeList>( {target->Dtype(), input->Dtype()}, [=]() { @@ -91,12 +96,11 @@ std::shared_ptr CrossEntropyForward(const std::shared_ptr &input Tinput *batched_loss_ptr = static_cast(batched_output->DataPtr()); // FIXME(dcj): do reduce on GPU CrossEntropyForwardKernel - <<Stream()>>>(input_ptr, target_ptr, batched_loss_ptr, - bs, num_classes); + <<>>(input_ptr, target_ptr, batched_loss_ptr, bs, + num_classes); - auto loss_cpu = batched_output->To(DeviceManager::Instance()->GetDefaultDevice()); - auto loss = std::make_shared(std::vector{}, input->Dtype(), - DeviceManager::Instance()->GetDefaultDevice()); + auto loss_cpu = batched_output->To(Device()); + auto loss = std::make_shared(std::vector{}, input->Dtype(), Device()); auto loss_cpu_typed_ptr = static_cast(loss_cpu.DataPtr()); static_cast(loss->DataPtr())[0] = std::accumulate(loss_cpu_typed_ptr, loss_cpu_typed_ptr + bs, 0.0f, @@ -186,7 +190,11 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu constexpr int threads_per_block = 256; int num_blocks = bs; - const auto *cuda_device = dynamic_cast(target->GetDevice()); + auto device = target->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + 
->cuda_stream(); + DispatchFunc, DataTypeList>( {target->Dtype(), input_casted->Dtype()}, [=]() { @@ -196,8 +204,8 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu const Tinput *input_ptr = static_cast(input_casted->DataPtr()); Tinput *input_grad_ptr = static_cast(grad_input->DataPtr()); CrossEntropyBackwardKernel - <<Stream()>>>(input_ptr, input_grad_ptr, target_ptr, - output_grad_ptr, bs, num_classes); + <<>>(input_ptr, input_grad_ptr, target_ptr, + output_grad_ptr, bs, num_classes); }, "CUDA CrossEntropyBackward"); @@ -206,7 +214,7 @@ std::shared_ptr CrossEntropyBackward(const std::shared_ptr &inpu } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_CROSS_ENTROPY_KERNEL(CrossEntropyForward) REGISTER_CUDA_CROSS_ENTROPY_KERNEL(CrossEntropyBackward) diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 913d848b..b3d65f30 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -6,6 +6,7 @@ #include "infini_train/include/common/cuda/kernel_helper.cuh" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { namespace { @@ -69,16 +70,18 @@ void LaunchKernel(Kernel &&kernel, const std::shared_ptr &output, const // Note: currently only support unary and binary operations template void LaunchForward(Func func, const std::shared_ptr &output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); T *output_ptr = static_cast(output->DataPtr()); if constexpr (sizeof...(inputs) == 1) { // Unary case LaunchKernel( [&](dim3 grid, dim3 block, size_t offset, auto... ptrs) { - UnaryForwardKernel<<>>(output_ptr, func, output->NumElements(), offset, - ptrs...); + UnaryForwardKernel<<>>(output_ptr, func, output->NumElements(), offset, + ptrs...); }, output, inputs...); } else if constexpr (sizeof...(inputs) == 2) { @@ -488,14 +491,18 @@ __global__ void BinaryBackwardKernel(T *output_a, T *output_b, FuncA fn_a, FuncB template void LaunchBackward(Func func, const std::shared_ptr &output, const std::shared_ptr &grad_output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + T *output_ptr = static_cast(output->DataPtr()); const T *grad_ptr = static_cast(grad_output->DataPtr()); LaunchKernel( [=](dim3 grid, dim3 block, size_t offset, auto... 
ptrs) { - UnaryBackwardKernel<<Stream()>>>(output_ptr, func, output->NumElements(), - offset, grad_ptr, ptrs...); + UnaryBackwardKernel<<>>(output_ptr, func, output->NumElements(), offset, + grad_ptr, ptrs...); }, output, inputs...); } @@ -506,8 +513,11 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out const std::shared_ptr &output_b, const std::vector &a_dims, const std::vector &b_dims, const std::shared_ptr &grad_output, const Inputs &...inputs) { - const auto *cuda_device = dynamic_cast(output_a->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output_a->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + T *output_a_ptr = static_cast(output_a->DataPtr()); T *output_b_ptr = static_cast(output_b->DataPtr()); const T *grad_output_ptr = static_cast(grad_output->DataPtr()); @@ -1081,7 +1091,7 @@ std::shared_ptr SigmoidBackward(const std::shared_ptr &output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_ELEMENTWISE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_ELEMENTWISE_KERNEL(NegForward) REGISTER_CUDA_ELEMENTWISE_KERNEL(NegBackward) diff --git a/infini_train/src/kernels/cuda/embedding.cu b/infini_train/src/kernels/cuda/embedding.cu index 6ae904f5..ddfeca98 100644 --- a/infini_train/src/kernels/cuda/embedding.cu +++ b/infini_train/src/kernels/cuda/embedding.cu @@ -3,6 +3,7 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -30,7 +31,11 @@ std::shared_ptr EmbeddingForward(const std::shared_ptr &input, c CHECK(input->Dtype() == DataType::kINT64); CHECK_EQ(weight->Dims().size(), 2); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int batch_size = input->Dims().size() == 2 ? input->Dims()[0] : 1; const int max_seqlen = input->Dims().size() == 2 ? 
input->Dims()[1] : input->Dims()[0]; const int vocab_size = weight->Dims()[0]; @@ -46,7 +51,7 @@ std::shared_ptr EmbeddingForward(const std::shared_ptr &input, c DispatchFunc( dtype, [=]() { - EmbeddingForwardKernel<<Stream()>>>( + EmbeddingForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), static_cast(weight->DataPtr()), batch_size, max_seqlen, embed_dim, vocab_size); }, @@ -77,7 +82,11 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, const std::shared_ptr &grad_output) { CHECK(input->Dtype() == DataType::kINT64); CHECK_EQ(weight_dims.size(), 2); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int vocab_size = weight_dims[0]; const int embedding_dim = weight_dims[1]; CHECK_EQ(input->Dims().size() + 1, grad_output->Dims().size()); @@ -94,7 +103,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, dtype, [=]() { grad_weight->Fill(0); - EmbeddingBackwardKernel<<Stream()>>>( + EmbeddingBackwardKernel<<>>( static_cast(input->DataPtr()), static_cast(grad_output->DataPtr()), static_cast(grad_weight->DataPtr()), num_tokens, embedding_dim, vocab_size); }, @@ -105,7 +114,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_EMBEDDING_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_EMBEDDING_KERNEL(EmbeddingForward) REGISTER_CUDA_EMBEDDING_KERNEL(EmbeddingBackward) diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index 2a601032..7874e5c3 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -4,6 +4,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -19,12 +20,15 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { const int num_tokens = tensor->NumElements(); const int threads_per_block = 256; const int num_blocks = (num_tokens + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(tensor->GetDevice()); + auto device = tensor->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( tensor->Dtype(), [=]() { - FillKernel<<Stream()>>>( + FillKernel<<>>( static_cast(tensor->DataPtr()), *(static_cast(value_ptr)), tensor->NumElements()); }, "CUDA Fill"); @@ -32,7 +36,7 @@ void Fill(std::shared_ptr tensor, void *value_ptr) { } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_FILL_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_FILL_KERNEL(Fill) diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index cc90d4a5..47898f3e 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -3,6 +3,7 @@ #include 
"infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { // FIXME(zbl): This kernel aligns with torch.gather @@ -44,8 +45,8 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); - CHECK(input->GetDevice()->Type() == index->GetDevice()->Type()); - CHECK(input->GetDevice()->Index() == index->GetDevice()->Index()); + CHECK(input->GetDevice().type() == index->GetDevice().type()); + CHECK(input->GetDevice().index() == index->GetDevice().index()); const int64_t num_dims = in_dims.size(); if (dim < 0) { @@ -66,11 +67,13 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, << "index.size(" << d << ") must be <= input.size(" << d << ") on non-gather dims"; } - const auto *cuda_dev = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_dev->Stream(); + const auto device = input->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); auto dtype = input->Dtype(); - auto out = std::make_shared(idx_dims, dtype, cuda_dev); + auto out = std::make_shared(idx_dims, dtype, device); auto in_strides = ComputeStrides(in_dims); auto out_strides = ComputeStrides(idx_dims); @@ -183,8 +186,11 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ const size_t n_out_strides = idx_dims.size(); const size_t total_i64 = n_out + n_in_strides + n_out_strides; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad_output->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMallocAsync(&dev_buf, total_i64 * sizeof(int64_t), stream)); int64_t *out_dims_dev = dev_buf; int64_t *in_strides_dev = out_dims_dev + n_out; @@ -216,7 +222,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_GATHER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_GATHER_KERNEL(IndexGatherForward) REGISTER_CUDA_GATHER_KERNEL(IndexGatherBackward) diff --git a/infini_train/src/kernels/cuda/layernorm.cu b/infini_train/src/kernels/cuda/layernorm.cu index ae825441..77b83509 100644 --- a/infini_train/src/kernels/cuda/layernorm.cu +++ b/infini_train/src/kernels/cuda/layernorm.cu @@ -5,6 +5,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -77,13 +78,17 @@ LayerNormForward(const std::shared_ptr &input, const std::shared_ptr(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { mean->Fill(0); rstd->Fill(0); - LayerNormForwardKernel<<Stream()>>>( + LayerNormForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(weight->DataPtr()), static_cast(bias->DataPtr()), 
static_cast(mean->DataPtr()), static_cast(rstd->DataPtr()), static_cast(output->DataPtr()), eps, embed_dim); @@ -168,14 +173,17 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); grad_weight->Fill(0); grad_bias->Fill(0); - LayerNormBackwardKernel<<Stream()>>>( + LayerNormBackwardKernel<<>>( static_cast(input->DataPtr()), static_cast(grad_output->DataPtr()), static_cast(mean->DataPtr()), static_cast(rstd->DataPtr()), static_cast(weight->DataPtr()), static_cast(grad_input->DataPtr()), @@ -189,7 +197,7 @@ LayerNormBackward(const std::shared_ptr &input, const std::shared_ptr MatmulForward(const std::shared_ptr &input, cons output_dims[output_dims.size() - 1] = n; auto output = std::make_shared(output_dims, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); // cuBLAS is colmun-major // output = input * other --> output.T = other.T * input.T @@ -129,9 +134,11 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr(input_promoted->GetDevice()); + auto device = input_promoted->GetDevice(); const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); { // cuBLAS is colmun-major @@ -230,7 +237,10 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons *output_dims.rbegin() = out_features; auto output = std::make_shared(output_dims, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); if (bias) { CHECK_EQ(bias->Dims().size(), 1); @@ -241,7 +251,7 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons DispatchFunc( dtype, [=]() { - BiasCopyKernel<<Stream()>>>( + BiasCopyKernel<<>>( static_cast(output->DataPtr()), static_cast(bias->DataPtr()), bs, out_features); }, "CUDA LinearForward"); @@ -255,7 +265,9 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons auto trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; auto trans_b = CUBLAS_OP_N; auto lda = transpose ? in_features : out_features; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); // TODO(zbl): use cublasSgemv if possible for convenience and simplicity // @@ -353,7 +365,11 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr( promoted_type, [=]() { initialize_gradients(T(0), promoted_type); }, "CUDA LinearBackward"); - const auto *cuda_device = dynamic_cast(input_promoted->GetDevice()); + auto device = input_promoted->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + float alpha = 1.0f; float beta = 0.0f; auto trans_a1 = transpose ? 
CUBLAS_OP_N : CUBLAS_OP_T; @@ -369,55 +385,57 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptrCublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (promoted_type) { // TODO(zbl): use cublasSgemv if possible - DISPATCH_CASE( - WRAP({ - // - if transpose: - // weight is [out_features, in_features] here - // d_input = d_output * weight --> d_input.T = weight.T * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[in_features, out_features] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // weight is [in_features, out_features] here - // d_input = d_output * weight.T --> d_input.T = weight * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[out_features, in_features] - // B = d_output.T[out_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, - static_cast(weight_promoted->DataPtr()), lda1, - static_cast(grad_output_promoted->DataPtr()), out_features, - &beta, static_cast(grad_input->DataPtr()), in_features)); - // - if transpose: - // d_weight = d_output.T * input --> d_weight.T = input.T * d_output - // C = d_weight.T[in_features, out_features] - // A = input.T[in_features, bs] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // d_weight = input.T * d_output --> d_weight.T = d_output.T * input - // C = d_weight.T[out_features, in_features] - // A = d_output.T[out_features, bs] - // B = input.T[in_features, bs] - CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, static_cast(a2), - lda2, static_cast(b2), ldb2, &beta, - static_cast(grad_weight->DataPtr()), ldc2)); - // d_bias = \sum_i(i=0, bs-1) d_output[i] - // TODO(dcj): use thrust::fill or reduce kernel do this - if (bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<Stream()>>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } - }), - DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + // - if transpose: + // weight is [out_features, in_features] here + // d_input = d_output * weight --> d_input.T = weight.T * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[in_features, out_features] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // weight is [in_features, out_features] here + // d_input = d_output * weight.T --> d_input.T = weight * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[out_features, in_features] + // B = d_output.T[out_features, bs] + CUBLAS_CHECK(cublasSgemm(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, + static_cast(weight_promoted->DataPtr()), lda1, + static_cast(grad_output_promoted->DataPtr()), + out_features, &beta, static_cast(grad_input->DataPtr()), + in_features)); + // - if transpose: + // d_weight = d_output.T * input --> d_weight.T = input.T * d_output + // C = d_weight.T[in_features, out_features] + // A = input.T[in_features, bs] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // d_weight = input.T * d_output --> d_weight.T = d_output.T * input + // C = d_weight.T[out_features, in_features] + // A = d_output.T[out_features, bs] + // B = input.T[in_features, bs] + CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, + static_cast(a2), lda2, static_cast(b2), + ldb2, &beta, static_cast(grad_weight->DataPtr()), 
ldc2)); + // d_bias = \sum_i(i=0, bs-1) d_output[i] + // TODO(dcj): use thrust::fill or reduce kernel do this + if (bias) { + constexpr int BLOCK_SIZE = 256; + int threads_per_block = BLOCK_SIZE; + int num_blocks = out_features; + ReduceColumnsKernel<<>>( + static_cast(grad_output_promoted->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); + } + }), + DataType::kFLOAT32) DISPATCH_CASE(WRAP({ CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, trans_b1, in_features, bs, out_features, &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1, @@ -431,10 +449,9 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr - <<Stream()>>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); + ReduceColumnsKernel<<>>( + static_cast(grad_output_promoted->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); } }), DataType::kBFLOAT16) @@ -445,7 +462,7 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr NoOpBackward(const std::vector &dims, const std } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_NO_OP_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_NO_OP_KERNEL(NoOpForward) REGISTER_CUDA_NO_OP_KERNEL(NoOpBackward) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index a0bcfe19..97db4b56 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -28,14 +28,16 @@ std::shared_ptr OuterForward(const std::shared_ptr &input, const auto output = std::make_shared(std::vector{M, N}, input->Dtype(), input->GetDevice()); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); // reinterpret input: [M] as column vector [M, 1] // reinterpret other: [N] as row vector [1, N] // output[M, N] = input[M, 1] * other.T[1, N] // output.T[N, M] = other[N, 1] * input.T[1, M] float alpha = 1.0f; float beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (input->Dtype()) { DISPATCH_CASE(WRAP({ @@ -97,10 +99,12 @@ std::tuple, std::shared_ptr> OuterBackward(const }, "CUDA OuterBackward"); - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); float alpha = 1.0f; float beta = 0.0f; - cublasHandle_t handle = cuda_device->CublasHandle(); + cublasHandle_t handle = dynamic_cast( + GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); switch (promoted_type) { DISPATCH_CASE(WRAP({ @@ -152,7 +156,7 @@ std::tuple, std::shared_ptr> OuterBackward(const } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_OUTER_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_OUTER_KERNEL(OuterForward) REGISTER_CUDA_OUTER_KERNEL(OuterBackward) diff --git a/infini_train/src/kernels/cuda/reduction.cu b/infini_train/src/kernels/cuda/reduction.cu index 7fd8c2c0..09519c45 100644 --- a/infini_train/src/kernels/cuda/reduction.cu +++ b/infini_train/src/kernels/cuda/reduction.cu @@ -2,8 +2,10 @@ 
#include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { namespace { @@ -132,14 +134,18 @@ std::shared_ptr ReduceOpForward(const std::shared_ptr &input, co int threads_per_block = BLOCK_SIZE; int num_blocks = N * W; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { GenericReduceKernel, BLOCK_SIZE> - <<Stream()>>>(static_cast(input->DataPtr()), - static_cast(output->DataPtr()), N, H, - W, FinalizeOp{}); + <<>>(static_cast(input->DataPtr()), + static_cast(output->DataPtr()), N, H, W, + FinalizeOp{}); }, "CUDA ReductionForward"); return output; @@ -164,12 +170,16 @@ std::shared_ptr ReduceOpBackward(const std::shared_ptr &grad_out int threads_per_block = 256; int num_blocks = (N * H * W + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + DispatchFunc( dtype, [=]() { grad_input->Fill(0); - GenericReduceBackwardKernel<<Stream()>>>( + GenericReduceBackwardKernel<<>>( static_cast(grad_input->DataPtr()), static_cast(grad_output->DataPtr()), input ? static_cast(input->DataPtr()) : nullptr, reduced ? static_cast(reduced->DataPtr()) : nullptr, N, H, W, is_mean, is_masked); @@ -217,7 +227,7 @@ std::shared_ptr MinBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_REDUCTION_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_REDUCTION_KERNEL(MeanForward) REGISTER_CUDA_REDUCTION_KERNEL(SumForward) diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 38d4aab6..4cf938db 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -4,8 +4,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -65,8 +67,11 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const int64_t *new_dims_dev, *starts_dev, *steps_dev, *input_strides_dev, *output_strides_dev; - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + cudaMallocAsync(&new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), stream); @@ -157,8 +162,10 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output int dims_size = dims.size(); int64_t *new_dims_dev, *starts_dev, *steps_dev, 
*input_strides_dev, *output_strides_dev; - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaMallocAsync(&new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), stream); @@ -194,7 +201,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SLICE_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SLICE_KERNEL(SliceForward) REGISTER_CUDA_SLICE_KERNEL(SliceBackward) diff --git a/infini_train/src/kernels/cuda/softmax.cu b/infini_train/src/kernels/cuda/softmax.cu index 622cb8a4..4289b0fa 100644 --- a/infini_train/src/kernels/cuda/softmax.cu +++ b/infini_train/src/kernels/cuda/softmax.cu @@ -7,8 +7,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { template @@ -86,9 +88,12 @@ void LaunchForward(const std::shared_ptr &output, const std::shared_ptr< dim3 block_dims(BLOCK_SIZE); dim3 grid_dims(outer_size, inner_size); - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); SoftmaxForwardKernel - <<Stream()>>>(output_ptr, input_ptr, outer_size, axis_size, inner_size); + <<>>(output_ptr, input_ptr, outer_size, axis_size, inner_size); } std::shared_ptr SoftmaxForward(const std::shared_ptr &input, int64_t dim) { @@ -167,9 +172,12 @@ void LaunchBackward(const std::shared_ptr &grad_input, const std::shared dim3 block(BLOCK_SIZE); dim3 grid(outer_size, inner_size); - const auto *cuda_device = dynamic_cast(output->GetDevice()); - SoftmaxBackwardKernel<<Stream()>>>( - grad_input_ptr, grad_output_ptr, output_ptr, outer_size, axis_size, inner_size); + auto device = output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + SoftmaxBackwardKernel<<>>(grad_input_ptr, grad_output_ptr, output_ptr, + outer_size, axis_size, inner_size); } std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_output, @@ -207,7 +215,7 @@ std::shared_ptr SoftmaxBackward(const std::shared_ptr &grad_outp } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SOFTMAX_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SOFTMAX_KERNEL(SoftmaxForward) REGISTER_CUDA_SOFTMAX_KERNEL(SoftmaxBackward) diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index ab22bf95..405cc967 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -50,12 +50,15 @@ std::vector> 
SplitForward(const std::shared_ptr int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { - SplitForwardKernel<<Stream()>>>( + SplitForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), N, H_in, H_out, W, start); }, @@ -114,8 +117,10 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di int64_t H_in = input_dims[dim]; int64_t num_splits = grad_outputs.size(); - const auto *cuda_device = dynamic_cast(grad->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // init the array of grad_output ptrs std::vector host_grad_output_ptrs; for (const auto &grad_output : grad_outputs) { @@ -165,7 +170,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_SPLIT_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_SPLIT_KERNEL(SplitForward) REGISTER_CUDA_SPLIT_KERNEL(SplitBackward) diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 5fe4899c..f8f81738 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -7,8 +7,10 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { template @@ -48,8 +50,10 @@ std::shared_ptr StackForward(const std::vector> const int64_t D = std::accumulate(base_dims.begin() + dim, base_dims.end(), 1, std::multiplies()); const int64_t num_inputs = inputs.size(); - const auto *cuda_device = dynamic_cast(output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = output->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * num_inputs * D; int threads_per_block = 256; @@ -115,8 +119,10 @@ std::vector> StackBackward(const std::vector &i int64_t N = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1, std::multiplies()); int64_t D = std::accumulate(input_dims.begin() + dim, input_dims.end(), 1, std::multiplies()); - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = grad_output->GetDevice(); + const auto &stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); int64_t total = N * num_inputs * D; int threads_per_block = 256; @@ -145,7 +151,7 @@ std::vector> StackBackward(const std::vector &i } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_STACK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, 
kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_STACK_KERNEL(StackForward) REGISTER_CUDA_STACK_KERNEL(StackBackward) diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 7f1f818d..e1831b1f 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -6,8 +6,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -38,12 +40,15 @@ std::shared_ptr TrilForward(const std::shared_ptr &input, int64_ int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - TrilForwardKernel<<Stream()>>>( + TrilForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, cols, diagonal); }, "CUDA TrilForward"); @@ -78,13 +83,16 @@ std::shared_ptr TrilBackward(const std::shared_ptr &grad_output, int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); - TrilBackwardKernel<<Stream()>>>( + TrilBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), rows, cols, diagonal); }, @@ -120,12 +128,15 @@ std::shared_ptr TriuForward(const std::shared_ptr &input, int64_ int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - TriuForwardKernel<<Stream()>>>( + TriuForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, cols, diagonal); }, "CUDA TriuForward"); @@ -159,13 +170,16 @@ std::shared_ptr TriuBackward(const std::shared_ptr &grad_output, int threads_per_block = 256; int num_blocks = (rows * cols + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( dtype, [=]() { grad_input->Fill(0); - TriuBackwardKernel<<Stream()>>>( + TriuBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), rows, cols, diagonal); }, @@ -229,8 +243,10 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i out_strides[i] = out_strides[i + 1] * out_dims[i + 1]; } - const auto *cuda_device = dynamic_cast(input->GetDevice()); - const auto &stream = cuda_device->Stream(); + auto device = input->GetDevice(); + const auto &stream + = 
dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // Allocate device memory for dims and strides // TODO(zbl): avoid using cudaMalloc? @@ -341,7 +357,11 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const MaskMode mode = DecideMaskMode(input_shape, mask_shape); auto output = std::make_shared(input_shape, dtype, input->GetDevice()); - const auto *cuda_device = dynamic_cast(output->GetDevice()); + auto device = output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + int threads_per_block = 256; if (mode == MaskMode::kLead) { @@ -352,7 +372,7 @@ DispatchFunc( dtype, [=]() { - MaskLeadsForwardKernel<<Stream()>>>( + MaskLeadsForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(mask->DataPtr()), static_cast(output->DataPtr()), common::cuda::Cast(value), rows, inner); }, @@ -365,7 +385,7 @@ DispatchFunc( dtype, [=]() { - MaskForwardKernel<<Stream()>>>( + MaskForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(mask_casted->DataPtr()), static_cast(output->DataPtr()), common::cuda::Cast(value), static_cast(batch_size), static_cast(mask_size)); @@ -401,7 +421,11 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, MaskMode mode = DecideMaskMode(output_shape, mask_shape); auto grad_input = std::make_shared(output_shape, dtype, grad_output->GetDevice()); - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + int threads_per_block = 256; if (mode == MaskMode::kLead) { @@ -413,7 +437,7 @@ dtype, [=]() { grad_input->Fill(0); - MaskLeadsBackwardKernel<<Stream()>>>( + MaskLeadsBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_casted->DataPtr()), static_cast(grad_input->DataPtr()), rows, inner); }, @@ -427,7 +451,7 @@ dtype, [=]() { grad_input->Fill(0); - MaskBackwardKernel<<Stream()>>>( + MaskBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_casted->DataPtr()), static_cast(grad_input->DataPtr()), static_cast(batch_size), static_cast(mask_size)); }, @@ -473,12 +497,15 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i int64_t total_elements = outer * dim_size * repeat * inner; int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( input->Dtype(), [=]() { - RepeatInterleaveForwardKernel<<Stream()>>>( + RepeatInterleaveForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(output->DataPtr()), outer, dim_size, inner, repeat); }, @@ -528,13 +555,16 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & int64_t total_elements = outer * dim_size * inner; int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = 
grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); DispatchFunc( grad_output->Dtype(), [=]() { grad_input->Fill(0); - RepeatInterleaveBackwardKernel<<Stream()>>>( + RepeatInterleaveBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(grad_input->DataPtr()), outer, dim_size, inner, repeat); }, @@ -545,7 +575,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_TRANSFORM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_TRANSFORM_KERNEL(TrilForward) REGISTER_CUDA_TRANSFORM_KERNEL(TrilBackward) diff --git a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu index 8b5d4450..d289403d 100644 --- a/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu +++ b/infini_train/src/kernels/cuda/vocab_parallel_cross_entropy.cu @@ -4,8 +4,10 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::kernels::cuda { @@ -74,10 +76,13 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, dloss_is_scalar = (grad_output->NumElements() == 1); } - const auto *cuda_device = dynamic_cast(grad_output->GetDevice()); + auto device = grad_output->GetDevice(); + const auto &cuda_stream + = dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); // logits should be [rows, V_local] - auto grad_input = std::make_shared(softmax_local->Dims(), softmax_local->Dtype(), cuda_device); + auto grad_input = std::make_shared(softmax_local->Dims(), softmax_local->Dtype(), device); const float one_minus_label_smoothing = 1.0f - label_smoothing; const float smoothing_term = (label_smoothing > 0.f && vocab_size_original > 0) @@ -100,10 +105,10 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, Tinput *grad_input_ptr = static_cast(grad_input->DataPtr()); VocabParallelCrossEntropyBackwardKernel - <<Stream()>>>( - softmax_ptr, grad_input_ptr, mtarget_ptr, tmask_ptr, vml_ptr, grad_output_ptr, - static_cast(rows), static_cast(vocab_size_local), dloss_is_scalar, - one_minus_label_smoothing, smoothing_term); + <<>>(softmax_ptr, grad_input_ptr, mtarget_ptr, tmask_ptr, + vml_ptr, grad_output_ptr, static_cast(rows), + static_cast(vocab_size_local), dloss_is_scalar, + one_minus_label_smoothing, smoothing_term); }, "CUDA VocabParallelCrossEntropyBackward"); @@ -112,7 +117,7 @@ VocabParallelCrossEntropyBackward(const std::shared_ptr &grad_output, } // namespace infini_train::kernels::cuda #define REGISTER_CUDA_VOCAB_PARALLEL_CROSS_ENTROPY_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_VOCAB_PARALLEL_CROSS_ENTROPY_KERNEL(VocabParallelCrossEntropyBackward) diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 
e00e2f8a..2f769aec 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -16,8 +16,10 @@ #include "glog/logging.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::init { namespace { @@ -46,26 +48,12 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean #endif auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + auto impl = core::GetDeviceGuardImpl(device.type()); - switch (device->Type()) { - case DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); - break; - } - } + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); return tensor; } @@ -152,26 +140,14 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, #endif auto device = tensor->GetDevice(); - device->SetDevice(); - switch (device->Type()) { - case DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); - break; - } - } + core::DeviceGuard guard(device); + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + return tensor; } @@ -182,26 +158,14 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { std::vector buffer(num_elements, 1.0f); auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); - switch (device->Type()) { - case DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? 
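// Condensed sketch of the replacement path used throughout init.cc in this diff
// (Normal/Uniform/Ones/Zeros): the per-device switch on the device type collapses into a
// DeviceGuard plus a single MemcpyAsync issued through the backend's guard implementation.
// Variable names mirror the hunks above; the helper name is hypothetical and the snippet is
// a sketch of the pattern, not the committed code. The MemcpyKind choice (kD2D on CPU,
// kH2D on CUDA) follows the hunks above.
#include <memory>
#include <vector>
#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/device.h"
#include "infini_train/include/tensor.h"

namespace infini_train::nn::init {
// Hypothetical helper for illustration only.
inline std::shared_ptr<Tensor> FillFromHost(const std::shared_ptr<Tensor> &tensor,
                                            const std::vector<float> &buffer) {
    auto device = tensor->GetDevice();
    core::DeviceGuard guard(device);                      // RAII replacement for device->SetDevice()
    auto impl = core::GetDeviceGuardImpl(device.type());
    impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), buffer.size() * sizeof(float),
                      device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D
                                                                : core::MemcpyKind::kH2D,
                      impl->GetStream(device));
    return tensor;
}
} // namespace infini_train::nn::init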
- cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); - break; - } - } return tensor; } @@ -212,26 +176,14 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { std::vector buffer(num_elements, 0.0f); auto device = tensor->GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); + + auto impl = core::GetDeviceGuardImpl(device.type()); + + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); - switch (device->Type()) { - case DeviceType::kCPU: { - memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float)); - break; - } -#ifdef USE_CUDA - case DeviceType::kCUDA: { - // TODO(dcj): maybe use async API later? - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream()); - break; - } -#endif - default: { - LOG(FATAL) << "Unsupported device type: " << static_cast(tensor->GetDevice()->Type()); - break; - } - } return tensor; } @@ -246,17 +198,19 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { case DATA_TYPE: { \ std::vector buffer(num_elements); \ std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ - cudaMemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), cudaMemcpyHostToDevice, \ - dynamic_cast(device)->Stream()); \ - break; \ - } + cudaMemcpyAsync( \ + tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), cudaMemcpyHostToDevice, \ + dynamic_cast(GetDeviceGuardImpl(device.type())->GetStream(device)) \ + ->cuda_stream()); \ + break; +} // namespace infini_train::nn::init -std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, const Device *device) { +std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Device device) { int64_t num_elements = end - start; auto tensor = std::make_shared(std::vector{num_elements}, dtype, device); - device->SetDevice(); + core::DeviceGuard guard(device); - if (device->IsCPU()) { + if (device.IsCPU()) { switch (dtype) { CASE(DataType::kUINT8, uint8_t) CASE(DataType::kINT8, int8_t) @@ -294,7 +248,7 @@ std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, const break; } #else - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); + LOG(FATAL) << "Unsupported device type: " << static_cast(device.type()); #endif } return tensor; diff --git a/infini_train/src/nn/modules/linear.cc b/infini_train/src/nn/modules/linear.cc index 67c0d733..e5a58d01 100644 --- a/infini_train/src/nn/modules/linear.cc +++ b/infini_train/src/nn/modules/linear.cc @@ -10,9 +10,9 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -Linear::Linear(int64_t in_features, int64_t out_features, bool bias, const Device *device) +Linear::Linear(int64_t in_features, int64_t out_features, bool bias, Device device) : CloneableModule(kType), bias_(bias) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); + device_ = device ? 
device : Device(); parameters_[kParamWeightName] = std::make_shared(std::vector{out_features, in_features}, DataType::kFLOAT32, device_) diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 4e0c6a28..e4c1ab98 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -14,7 +14,7 @@ namespace infini_train::nn { Module::Module() : Module(kUndefinedType) {} -Module::Module(const std::string &type) : type_(type), device_(DeviceManager::Instance()->GetDefaultDevice()) {} +Module::Module(const std::string &type) : type_(type), device_(Device()) {} const std::string &Module::type() const { return type_; } @@ -125,8 +125,7 @@ std::vector> Module::Forward(const std::vector &normalized_shape, float eps, const Device *device) : eps_(eps) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); +LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, Device device) : eps_(eps) { + device_ = device ? device : Device(); parameters_[kParamWeightName] = std::make_shared(normalized_shape, DataType::kFLOAT32, device_)->RequiresGrad(); diff --git a/infini_train/src/nn/modules/sparse.cc b/infini_train/src/nn/modules/sparse.cc index ab845697..9314fe6d 100644 --- a/infini_train/src/nn/modules/sparse.cc +++ b/infini_train/src/nn/modules/sparse.cc @@ -10,9 +10,8 @@ namespace infini_train::nn { -Embedding::Embedding(int num_embeddings, int embedding_dim, const Device *device) : CloneableModule(kType) { - device_ = device ? device : DeviceManager::Instance()->GetDefaultDevice(); - +Embedding::Embedding(int num_embeddings, int embedding_dim, Device device) : CloneableModule(kType) { + device_ = device; parameters_[kParamWeightName] = std::make_shared(std::vector{num_embeddings, embedding_dim}, DataType::kFLOAT32, device_) ->RequiresGrad(); diff --git a/infini_train/src/nn/parallel/data_parallel.cc b/infini_train/src/nn/parallel/data_parallel.cc index 0dec0c8f..77c23986 100644 --- a/infini_train/src/nn/parallel/data_parallel.cc +++ b/infini_train/src/nn/parallel/data_parallel.cc @@ -8,8 +8,10 @@ #include "glog/logging.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/tensor.h" @@ -19,8 +21,7 @@ constexpr char kModuleName[] = "module"; std::vector>> ParallelApply(const std::vector> &modules, - const std::vector>> &inputs, - const std::vector &devices) { + const std::vector>> &inputs, const std::vector &devices) { CHECK_EQ(modules.size(), inputs.size()) << std::format( "The number of modules {} is not equal to the number of inputs {}", modules.size(), inputs.size()); CHECK_EQ(modules.size(), devices.size()); @@ -29,8 +30,8 @@ ParallelApply(const std::vector> &modules, std::vector>>> results(modules.size(), std::nullopt); auto worker = [&](const std::shared_ptr &module, const std::vector> &inputs, - const Device *device, int idx) { - device->SetDevice(); + Device device, int idx) { + core::DeviceGuard guard(device); auto output = module->Forward(inputs); results[idx] = output; }; @@ -57,8 +58,10 @@ ParallelApply(const std::vector> &modules, } } // namespace -DataParallel::DataParallel(const std::shared_ptr &module, int dim) - : dim_(dim), devices_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA)) { +DataParallel::DataParallel(const 
std::shared_ptr &module, int dim, Device::DeviceType device_type) : dim_(dim) { + devices_.reserve(global::GetNthreadPerProc()); + for (int index = 0; index < global::GetNthreadPerProc(); ++index) { devices_.emplace_back(device_type, index); } + CHECK_GT(devices_.size(), 0) << "No available devices found"; output_device_ = devices_.at(0); src_device_ = devices_.at(0); diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index a25a7d16..b40122e4 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -21,7 +21,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod const ReducerOptions &opts) { for (auto ¶m : module->Parameters()) { auto device = param->GetDevice(); - CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; + CHECK_EQ(device.index(), device_id) << "All parameters must be on the same device as the module"; if (!opts.gradient_bucketing_enabled) { auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank())); @@ -31,7 +31,7 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod } } for (auto &buffer : module->Buffers()) { - CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module"; + CHECK_EQ(buffer->GetDevice().index(), device_id) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 50408949..719365b4 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -15,7 +15,7 @@ namespace infini_train::nn::parallel::function { std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { - auto device = tensor->GetDevice()->Type(); + auto device = tensor->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } @@ -24,7 +24,7 @@ std::shared_ptr AllReduce(const std::shared_ptr &tensor, ReduceOpT std::shared_ptr AllGather(const std::shared_ptr &output, const std::shared_ptr &input, const ProcessGroup *pg, bool async_op) { - auto device = output->GetDevice()->Type(); + auto device = output->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } @@ -33,7 +33,7 @@ std::shared_ptr AllGather(const std::shared_ptr &output, const std std::shared_ptr ReduceScatter(const std::shared_ptr &output, const std::shared_ptr &input, ReduceOpType reduce_op, const ProcessGroup *pg, bool async_op) { - auto device = output->GetDevice()->Type(); + auto device = output->GetDevice().type(); if (pg == nullptr) { pg = ProcessGroupFactory::Instance()->GetDefaultProcessGroup(); } @@ -41,7 +41,7 @@ std::shared_ptr ReduceScatter(const std::shared_ptr &output, const } std::vector>> Scatter(const std::vector> &input_tensors, - const std::vector &devices, int dim) { + const std::vector &devices, int dim) { std::vector>> output_tensors; for (const auto &tensor : input_tensors) { output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); @@ -56,15 +56,14 @@ std::vector>> Scatter(const std::vector> Gather(const std::vector>> &tensors, - const Device *target_device, int 
dim) { + Device target_device, int dim) { std::vector> gather_tensors; for (const auto &tensor : tensors) { gather_tensors.push_back(tensor[0]); } return std::make_shared(target_device, dim)->Apply(gather_tensors); } std::vector>> -BroadcastCoalescedReshape(const std::vector> &tensors, - const std::vector &devices) { +BroadcastCoalescedReshape(const std::vector> &tensors, const std::vector &devices) { if (tensors.empty()) { return {}; } @@ -80,7 +79,7 @@ BroadcastCoalescedReshape(const std::vector> &tensors, } std::vector> Replicate(const std::shared_ptr &network, - const std::vector &devices) { + const std::vector &devices) { const int num_replicas = devices.size(); // FIXME(dcj): Parameters function need deduplication diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 95dd3bbc..f109708f 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -213,7 +213,7 @@ float PipelineSchedule::StepMicroBatches(const std::vectordevice()->Type(), dtype); + infini_train::AutocastGuard autocast_guard(stage_->device().type(), dtype); std::vector> inputs; @@ -241,15 +241,14 @@ float PipelineSchedule::StepMicroBatches(const std::vector loss; { - infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + infini_train::AutocastGuard autocast_guard(stage_->device().type(), dtype); auto target_on_device = target->To(activations[task.local_chunk_idx][mb][0]->GetDevice()); loss = loss_fn->Forward( {activations[task.local_chunk_idx][mb][0], std::make_shared(target_on_device)})[0]; loss = loss / n; } - total_loss - += static_cast(loss->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + total_loss += static_cast(loss->To(Device()).DataPtr())[0]; loss->Backward(); } else { diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index 582b9bd2..46a1b0bf 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -12,12 +12,10 @@ namespace infini_train::nn::parallel { PipelineStage::PipelineStage(int stage_index /* pp_rank */, int num_stages /* pp_size */, const std::vector> &recv_shape, std::shared_ptr optimizer, - int device_id, std::vector> &&chunks) + Device device, std::vector> &&chunks) : stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), next_rank_(stage_index < num_stages - 1 ? 
stage_index + 1 : -1), recv_shape_(recv_shape), - optimizer_(std::move(optimizer)), - device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)), - chunks_(std::move(chunks)) {} + optimizer_(std::move(optimizer)), device_(device), chunks_(std::move(chunks)) {} std::vector> PipelineStage::ForwardOneChunk(const std::vector> &inputs, int local_chunk_idx) { @@ -36,7 +34,7 @@ int PipelineStage::prev_rank() const { return prev_rank_; } int PipelineStage::next_rank() const { return next_rank_; } int PipelineStage::num_stages() const { return num_stages_; } -const Device *PipelineStage::device() const { return device_; } +Device PipelineStage::device() const { return device_; } const std::vector> &PipelineStage::recv_shape() const { return recv_shape_; } std::shared_ptr PipelineStage::optimizer() { return optimizer_; } const std::vector> &PipelineStage::chunks() { return chunks_; } diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index bac71f0b..afcdaac2 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -18,8 +18,7 @@ class ISend : public autograd::Function { public: static constexpr char kType[] = "ISendFunction"; - explicit ISend(const Device *target_device, int cur_rank, int peer_rank, - const std::vector> &shape) + explicit ISend(Device target_device, int cur_rank, int peer_rank, const std::vector> &shape) : autograd::Function(kType), target_device_(target_device), cur_rank_(cur_rank), peer_rank_(peer_rank), shapes_(shape) {} @@ -28,8 +27,8 @@ class ISend : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - const Device *target_device_ = nullptr; - const Device *input_device_ = nullptr; + Device target_device_ = Device(); + Device input_device_ = Device(); int cur_rank_ = -1; int peer_rank_ = -1; const std::vector> &shapes_; @@ -39,7 +38,7 @@ class IRecv : public autograd::Function { public: static constexpr char kType[] = "IRecvFunction"; - explicit IRecv(const Device *src_device, int cur_rank, int peer_rank) + explicit IRecv(Device src_device, int cur_rank, int peer_rank) : autograd::Function(kType), src_device_(src_device), cur_rank_(cur_rank), peer_rank_(peer_rank) {} std::vector> Forward(const std::vector> &input_tensors) override; @@ -50,8 +49,8 @@ class IRecv : public autograd::Function { std::vector> Backward(const std::vector> &grad_outputs) override; private: - const Device *src_device_ = nullptr; - const Device *cur_device_ = nullptr; + Device src_device_ = Device(); + Device cur_device_ = Device(); int cur_rank_ = -1; int peer_rank_ = -1; }; @@ -112,14 +111,14 @@ std::vector> IRecv::Backward(const std::vector> ISend(const std::vector> &input_tensors, - const Device *target_device, int cur_rank, int peer_rank, + Device target_device, int cur_rank, int peer_rank, const std::vector> &shape) { auto func = std::make_shared(target_device, cur_rank, peer_rank, shape); return func->Apply(input_tensors); } -std::vector> IRecv(const std::vector> &outputs, - const Device *src_device, int cur_rank, int peer_rank) { +std::vector> IRecv(const std::vector> &outputs, Device src_device, + int cur_rank, int peer_rank) { auto func = std::make_shared(src_device, cur_rank, peer_rank); return func->Apply(outputs); } diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 50a75d48..2489bb37 100644 --- 
a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -18,11 +18,13 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { @@ -128,10 +130,10 @@ void ProcessGroupNCCL::InitSingleProcess(const std::vector &ranks) { NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data())); for (int i = 0; i < ranks.size(); ++i) { - auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]); + auto device = Device(Device::DeviceType::kCUDA, ranks[i]); devices_.push_back(device); device_comm_map_[device] = comms_[i]; - global_group_rank_map_[device->rank().GlobalRank()] = i; + global_group_rank_map_[device.Rank().GlobalRank()] = i; } } @@ -165,8 +167,8 @@ void ProcessGroupNCCL::InitMultiProcess(const std::vector &ranks) { NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank)); comms_.push_back(comm); - auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i); - global_group_rank_map_[device->rank().GlobalRank()] = group_rank; + auto device = Device(Device::DeviceType::kCUDA, i); + global_group_rank_map_[device.Rank().GlobalRank()] = group_rank; devices_.push_back(device); device_comm_map_[device] = comm; } @@ -179,7 +181,8 @@ void ProcessGroupNCCL::InitStreams() { comm_streams_.resize(device_size); for (int i = 0; i < device_size; ++i) { - devices_[i]->SetDevice(); + core::DeviceGuard guard(devices_[i]); + int low, high; CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&low, &high)); CUDA_CHECK(cudaStreamCreateWithPriority(&comm_streams_[i], cudaStreamNonBlocking, high)); @@ -190,12 +193,14 @@ void ProcessGroupNCCL::InitStreams() { std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op, bool async_op) const { void *buffer = tensor->DataPtr(); - const auto *device = dynamic_cast(tensor->GetDevice()); - device->SetDevice(); + auto device = tensor->GetDevice(); + core::DeviceGuard guard(device); auto comm = device_comm_map_.at(device); - cudaStream_t compute_stream = device->Stream(); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaStream_t comm_stream = device_stream_map_.at(device); auto work = std::make_shared(device, comm); @@ -222,12 +227,14 @@ std::shared_ptr ProcessGroupNCCL::AllReduce(const std::shared_ptr std::shared_ptr ProcessGroupNCCL::AllGather(const std::shared_ptr &output, const std::shared_ptr &input, bool async_op) const { - const auto *device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); auto comm = device_comm_map_.at(device); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaStream_t comm_stream = device_stream_map_.at(device); auto work = std::make_shared(device, comm); @@ -254,12 +261,14 @@ std::shared_ptr ProcessGroupNCCL::AllGather(const std::shared_ptr std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptr &output, const std::shared_ptr 
&input, function::ReduceOpType reduce_op, bool async_op) const { - const auto *device = dynamic_cast(input->GetDevice()); + auto device = input->GetDevice(); auto comm = device_comm_map_.at(device); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaStream_t comm_stream = device_stream_map_.at(device); auto work = std::make_shared(device, comm); @@ -286,12 +295,14 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter(const std::shared_ptr ProcessGroupNCCL::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); - const auto *device = dynamic_cast(tensors[0]->GetDevice()); + auto device = tensors[0]->GetDevice(); auto comm = device_comm_map_.at(device); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaStream_t comm_stream = device_stream_map_.at(device); auto work = std::make_shared(device, comm); @@ -332,12 +343,14 @@ std::shared_ptr ProcessGroupNCCL::Send(std::vector std::shared_ptr ProcessGroupNCCL::Recv(std::vector> tensors, int src_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); - const auto *device = dynamic_cast(tensors[0]->GetDevice()); + auto device = tensors[0]->GetDevice(); auto comm = device_comm_map_.at(device); - device->SetDevice(); + core::DeviceGuard guard(device); - cudaStream_t compute_stream = device->Stream(); + cudaStream_t compute_stream = dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); cudaStream_t comm_stream = device_stream_map_.at(device); auto work = std::make_shared(device, comm); @@ -380,7 +393,7 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te std::vector> outputs; std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; CHECK_EQ(world_size_, comms_.size()); @@ -390,7 +403,9 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te outputs.push_back(std::make_shared(input_tensor->Dims(), input_tensor->Dtype(), device)); } devices.push_back(device); - streams.push_back(dynamic_cast(device)->Stream()); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream()); comms.push_back(device_comm_map_.at(device)); } @@ -405,7 +420,8 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < devices.size(); ++i) { - devices[i]->SetDevice(); + core::DeviceGuard guard(devices[i]); + for (size_t j = 0; j < input_tensors.size(); ++j) { const auto &input_tensor = input_tensors[j]; const auto dtype = input_tensor->Dtype(); @@ -423,12 +439,12 @@ ProcessGroupNCCL::BroadCast(const std::vector> &input_te std::vector> ProcessGroupNCCL::ReduceAddCoalesced(const std::vector>> &grads, - const Device *destination) const { + Device destination) const { // grads: [devices, tensors] std::vector> outputs; std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; for (size_t i = 0; i < grads[0].size(); ++i) { outputs.push_back(std::make_shared(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination)); @@ -436,7 +452,9 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vectorGetDevice()); - streams.push_back(dynamic_cast(devices[i])->Stream()); + 
streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i])) + ->cuda_stream()); comms.push_back(device_comm_map_.at(devices[i])); } @@ -451,7 +469,8 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vectorSetDevice(); + core::DeviceGuard guard(devices[i]); + for (size_t j = 0; j < grads[i].size(); ++j) { const auto &grad = grads[i][j]; const auto dtype = grad->Dtype(); @@ -468,7 +487,7 @@ ProcessGroupNCCL::ReduceAddCoalesced(const std::vector> ProcessGroupNCCL::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { + std::vector devices, int64_t dim) const { std::vector> outputs; std::vector> split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); std::vector streams; @@ -479,7 +498,9 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared src_rank = i; } outputs.push_back(std::make_shared(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i])); - streams.push_back(dynamic_cast(devices[i])->Stream()); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(devices[i].type())->GetStream(devices[i])) + ->cuda_stream()); comms.push_back(device_comm_map_.at(devices[i])); } @@ -490,7 +511,8 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared auto nccl_dtype = kNcclDtypeMap.at(dtype); for (size_t i = 0; i < devices.size(); ++i) { - devices[i]->SetDevice(); + core::DeviceGuard guard(devices[i]); + const auto dtype = tensor->Dtype(); auto nccl_dtype = kNcclDtypeMap.at(dtype); NCCL_CHECK(ncclSend(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), nccl_dtype, i, @@ -503,7 +525,7 @@ std::vector> ProcessGroupNCCL::Scatter(const std::shared } std::shared_ptr ProcessGroupNCCL::Gather(const std::vector> &tensors, - const Device *destination, int64_t dim) const { + Device destination, int64_t dim) const { std::vector> outouts; int64_t num_devices = tensors.size(); auto dtype = tensors[0]->Dtype(); @@ -513,7 +535,7 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vector streams; std::vector comms; - std::vector devices; + std::vector devices; int dest_rank = -1; for (size_t i = 0; i < tensors.size(); ++i) { @@ -521,7 +543,9 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vector(device)->Stream()); + streams.push_back(dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream()); comms.push_back(device_comm_map_.at(device)); devices.push_back(device); @@ -538,7 +562,8 @@ std::shared_ptr ProcessGroupNCCL::Gather(const std::vectorSetDevice(); + core::DeviceGuard guard(devices[i]); + auto &tensor = tensors[i]; size_t num_elements = tensor->NumElements(); void *send_ptr = tensor->DataPtr(); diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 0cdd7703..16fd9060 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -12,9 +12,11 @@ #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::parallel { namespace { @@ -26,17 +28,21 @@ void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr char *dst = static_cast(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes; const void *src = grad->DataPtr(); - const auto dev_type = 
grad->GetDevice()->Type(); - if (dev_type == DeviceType::kCPU) { + const auto dev_type = grad->GetDevice().type(); + if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; } #ifdef USE_CUDA - if (dev_type == DeviceType::kCUDA) { - auto *cuda_dev = dynamic_cast(flat->GetDevice()); - CHECK(cuda_dev); - cuda_dev->SetDevice(); - auto comm_stream = stream ? reinterpret_cast(stream) : cuda_dev->Stream(); + if (dev_type == Device::DeviceType::kCUDA) { + auto cuda_dev = flat->GetDevice(); + + core::DeviceGuard guard(cuda_dev); + + auto comm_stream = stream ? reinterpret_cast(stream) + : dynamic_cast( + core::GetDeviceGuardImpl(cuda_dev.type())->GetStream(cuda_dev)) + ->cuda_stream(); cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); return; } @@ -52,17 +58,21 @@ void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr const char *src = static_cast(flat->DataPtr()) + src_elem_offset * element_size_in_bytes; void *dst = grad->DataPtr(); - const auto dev_type = grad->GetDevice()->Type(); - if (dev_type == DeviceType::kCPU) { + const auto dev_type = grad->GetDevice().type(); + if (dev_type == Device::DeviceType::kCPU) { std::memcpy(dst, src, bytes); return; } #ifdef USE_CUDA - if (dev_type == DeviceType::kCUDA) { - auto *cuda_dev = dynamic_cast(flat->GetDevice()); - CHECK(cuda_dev); - cuda_dev->SetDevice(); - auto comm_stream = stream ? reinterpret_cast(stream) : cuda_dev->Stream(); + if (dev_type == Device::DeviceType::kCUDA) { + auto cuda_dev = flat->GetDevice(); + + core::DeviceGuard guard(cuda_dev); + + auto comm_stream = stream ? reinterpret_cast(stream) + : dynamic_cast( + core::GetDeviceGuardImpl(cuda_dev.type())->GetStream(cuda_dev)) + ->cuda_stream(); cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); return; } @@ -135,7 +145,7 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const auto &tensor = tensors[idx_in_order]; CHECK(tensor); - const Key k = Key{tensors[idx_in_order]->GetDevice()->Index(), tensors[idx_in_order]->Dtype()}; + const Key k = Key{tensors[idx_in_order]->GetDevice().index(), tensors[idx_in_order]->Dtype()}; auto it = states.find(k); if (it == states.end()) { it = states.emplace(k, State{}).first; diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index b91028df..27a8a67f 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -15,7 +15,6 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/utils.h" -#include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { @@ -534,7 +533,7 @@ VocabParallelCrossEntropy::Backward(const std::vector> & auto masked_target = saved_tensors_[2]; auto valid_mask_local = saved_tensors_[3]; - auto device = grad_output->GetDevice()->Type(); + auto device = grad_output->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>( {device, "VocabParallelCrossEntropyBackward"}, grad_output, softmax_local, target_mask, masked_target, valid_mask_local, vocab_size_local_, vocab_size_original_, label_smoothing_); diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc index 53fd465a..57018258 100644 --- a/infini_train/src/nn/parallel/work.cc +++ b/infini_train/src/nn/parallel/work.cc @@ -5,7 +5,9 @@ #ifdef USE_CUDA 
#include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train::nn::parallel { #ifdef USE_NCCL @@ -15,7 +17,7 @@ std::exception_ptr makeCudaError(cudaError_t err) { } } // namespace -WorkNccl::WorkNccl(const Device *device, ncclComm_t comm) : device_(device), comm_(comm) { +WorkNccl::WorkNccl(Device device, ncclComm_t comm) : device_(device), comm_(comm) { CUDA_CHECK(cudaEventCreateWithFlags(&ready_event_, cudaEventDisableTiming)); CUDA_CHECK(cudaEventCreateWithFlags(&done_event_, cudaEventDisableTiming)); } @@ -31,7 +33,7 @@ WorkNccl::~WorkNccl() { bool WorkNccl::WaitBlocking(std::chrono::milliseconds timeout) { // Block wait on host - device_->SetDevice(); + core::DeviceGuard guard(device_); // If timeout is not set, then wait till it finishes if (timeout <= std::chrono::milliseconds::zero()) { @@ -68,8 +70,11 @@ bool WorkNccl::WaitBlocking(std::chrono::milliseconds timeout) { bool WorkNccl::WaitNonBlocking() { // Non-blocking wait on compute stream - device_->SetDevice(); - CUDA_CHECK(cudaStreamWaitEvent(dynamic_cast(device_)->Stream(), done_event_, 0)); + core::DeviceGuard guard(device_); + CUDA_CHECK(cudaStreamWaitEvent(dynamic_cast( + core::GetDeviceGuardImpl(device_.type())->GetStream(device_)) + ->cuda_stream(), + done_event_, 0)); return true; } diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 80e3887f..8eacafa3 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -2,6 +2,7 @@ #include +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -25,8 +26,8 @@ void SGD::Step() { continue; } auto device = param->GetDevice(); - device->SetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + core::DeviceGuard guard(device); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); kernel.Call(param->grad(), -learning_rate_, param); } } @@ -61,8 +62,8 @@ void Adam::Step() { auto &v = v_[i]; auto device = param->GetDevice(); - device->SetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AdamAccumulateGrad"}); + core::DeviceGuard guard(device); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AdamAccumulateGrad"}); kernel.Call(grad, param, m, v, learning_rate_, beta1_, beta2_, eps_, t_); } } diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index f2be2f4b..01bf68aa 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -14,7 +14,9 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/core/device_guard.h" #include "infini_train/include/device.h" +#include "infini_train/src/core/cuda/cuda_stream.h" namespace infini_train { namespace { @@ -38,8 +40,8 @@ Profiler &Profiler::Instance() { return profiler; } -int GetRank(DeviceType device) { - if (device == DeviceType::kCPU) { +int GetRank(Device::DeviceType device) { + if (device == Device::DeviceType::kCPU) { return 0; } @@ -53,25 +55,26 @@ int GetRank(DeviceType device) { #ifdef USE_CUDA cudaStream_t GetCudaStream() { - int device_id = GetRank(DeviceType::kCUDA); + int device_id = GetRank(Device::DeviceType::kCUDA); // TODO(zbl): support multi-stream on single device - 
return dynamic_cast( - DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, static_cast(device_id))) - ->Stream(); + auto device = Device(Device::DeviceType::kCUDA, static_cast(device_id)); + return dynamic_cast( + core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); } #endif -void Profiler::StartRecord(const std::string &name, DeviceType device) { +void Profiler::StartRecord(const std::string &name, Device::DeviceType device) { if (g_profiling_depth++ > 0) { return; } cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); switch (device) { - case DeviceType::kCPU: + case Device::DeviceType::kCPU: break; #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { auto it = cuda_timing_map_.find(name); if (it != cuda_timing_map_.end()) { // Make sure there are no conflicts @@ -100,7 +103,7 @@ void Profiler::StartRecord(const std::string &name, DeviceType device) { } } -void Profiler::EndRecord(const std::string &name, DeviceType device) { +void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { if (--g_profiling_depth > 0) { return; } @@ -110,10 +113,10 @@ void Profiler::EndRecord(const std::string &name, DeviceType device) { int rank = GetRank(device); switch (device) { - case DeviceType::kCPU: + case Device::DeviceType::kCPU: break; #ifdef USE_CUDA - case DeviceType::kCUDA: { + case Device::DeviceType::kCUDA: { auto it = cuda_timing_map_.find(name); if (it != cuda_timing_map_.end()) { auto event_pair = it->second; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index c5de11ce..2d27478e 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -1,5 +1,4 @@ #include "infini_train/include/tensor.h" -#include "infini_train/include/datatype.h" #include #include @@ -7,16 +6,9 @@ #include #include -#ifdef USE_CUDA -#include -#endif - #include "Eigen/Dense" #include "glog/logging.h" -#ifdef USE_CUDA -#include "infini_train/include/common/cuda/common_cuda.h" -#endif #include "infini_train/include/autograd/accumulate.h" #include "infini_train/include/autograd/elementwise.h" #include "infini_train/include/autograd/function.h" @@ -26,61 +18,35 @@ #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/transform.h" +#include "infini_train/include/core/device_guard.h" +#include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/init.h" namespace infini_train { -TensorBuffer::TensorBuffer(const Device *device, size_t size) : device_(device), size_(size) { - CHECK_NOTNULL(device); - switch (device_->Type()) { - case DeviceType::kCPU: - data_ = malloc(size); - break; -#ifdef USE_CUDA - case DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - // TODO(dcj): Maybe pin memory later. 
- device->SetDevice(); - const auto *cuda_device = dynamic_cast(device); - CUDA_CHECK(cudaMallocAsync(&data_, size, cuda_device->Stream())); - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_->Type()); - break; - } +TensorBuffer::TensorBuffer(Device device, size_t size) : device_(device), size_(size) { + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MallocAsync(&data_, size, impl->GetStream(device)); } TensorBuffer::~TensorBuffer() { - switch (device_->Type()) { - case DeviceType::kCPU: - free(data_); - break; -#ifdef USE_CUDA - case DeviceType::kCUDA: - CUDA_CHECK(cudaFreeAsync(data_, dynamic_cast(device_)->Stream())); - break; -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device_->Type()); - break; - } + core::DeviceGuard guard(device_); + auto *impl = core::GetDeviceGuardImpl(device_.type()); + impl->FreeAsync(data_, impl->GetStream(device_)); } void *TensorBuffer::DataPtr() { return data_; } const void *TensorBuffer::DataPtr() const { return data_; } -const Device *TensorBuffer::GetDevice() const { return device_; } +Device TensorBuffer::GetDevice() const { return device_; } size_t TensorBuffer::Size() const { return size_; } // Tensor implementation -Tensor::Tensor(const std::vector &dims, DataType dtype, const Device *device) : dims_(dims), dtype_(dtype) { +Tensor::Tensor(const std::vector &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype) { num_elements_ = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); buffer_ = std::make_shared(device, kDataTypeToSize.at(dtype) * num_elements_); } @@ -91,29 +57,22 @@ Tensor::Tensor(const Tensor &tensor, size_t offset, const std::vector & CHECK_LE(offset_ + kDataTypeToSize.at(dtype_) * num_elements_, buffer_->Size()); } -Tensor::Tensor(const float *data, const std::vector &dims, DataType dtype, const Device *device) +Tensor::Tensor(const float *data, const std::vector &dims, DataType dtype, Device device) : dims_(dims), dtype_(dtype), num_elements_(std::accumulate(dims.begin(), dims.end(), 1, std::multiplies())) { // TODO(dcj): support more datatype CHECK(dtype == DataType::kFLOAT32); buffer_ = std::make_shared(device, kDataTypeToSize.at(dtype) * num_elements_); - switch (device->Type()) { - case DeviceType::kCPU: - memcpy(buffer_->DataPtr(), data, buffer_->Size()); - break; -#ifdef USE_CUDA - case DeviceType::kCUDA: - CUDA_CHECK(cudaMemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream())); - break; -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); - } + + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); } -const Device *Tensor::GetDevice() const { return buffer_->GetDevice(); } +Device Tensor::GetDevice() const { return buffer_->GetDevice(); } void *Tensor::DataPtr() { return reinterpret_cast(buffer_->DataPtr()) + offset_; } @@ -127,38 +86,6 @@ size_t Tensor::NumElements() const { return num_elements_; } DataType Tensor::Dtype() const { return dtype_; } -template void Tensor::Fill(T value) { - auto device = GetDevice(); - device->SetDevice(); - - DataType dtype = Dtype(); - - uint64_t storage = 0; - - DispatchFunc(Dtype(), [&storage, value]() { - TargetT casted_value = static_cast(value); - std::memcpy((void *)(&storage), &casted_value, sizeof(TargetT)); - }); - - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "Fill"}); - kernel.Call(shared_from_this(), static_cast(&storage)); -} - -template void Tensor::Fill(uint8_t); -template void Tensor::Fill(int8_t); -template void Tensor::Fill(uint16_t); -template void Tensor::Fill(int16_t); -template void Tensor::Fill(uint32_t); -template void Tensor::Fill(int32_t); -template void Tensor::Fill(uint64_t); -template void Tensor::Fill(int64_t); -template void Tensor::Fill(float); -template void Tensor::Fill(double); -#ifdef USE_CUDA -template void Tensor::Fill(nv_bfloat16); -template void Tensor::Fill(half); -#endif - Eigen::Map> Tensor::EigenMatrix() { const int64_t bs = std::accumulate(dims_.rbegin() + 1, dims_.rend(), 1, std::multiplies()); return Eigen::Map>( @@ -171,8 +98,9 @@ Eigen::Map> Tensor::Eig dims_[0]); } -Tensor Tensor::To(const Device *device) { - if (device == buffer_->GetDevice()) { +Tensor Tensor::To(Device device) { + const auto buffer_device = buffer_->GetDevice(); + if (device == buffer_device) { auto new_tensor = Tensor(*this, offset_, dims_); if (grad_) { new_tensor.grad_ = std::make_unique(*grad_.get(), grad_->offset_, grad_->dims_); @@ -181,39 +109,31 @@ Tensor Tensor::To(const Device *device) { } Tensor new_tensor; - switch (device->Type()) { -#ifdef USE_CUDA - case DeviceType::kCPU: { - // CUDA -> CPU - GetDevice()->SetDevice(); - new_tensor = Tensor(dims_, dtype_, DeviceManager::Instance()->GetDefaultDevice()); - CUDA_CHECK(cudaMemcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyDeviceToHost)); - break; - } - case DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); + if (device.type() == Device::DeviceType::kCPU) { + // D2H + new_tensor = Tensor(dims_, dtype_, Device()); + core::DeviceGuard guard(buffer_device); + auto impl = core::GetDeviceGuardImpl(buffer_device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H, + impl->GetStream(buffer_device)); + + } else if (buffer_device.type() == Device::DeviceType::kCPU) { new_tensor = Tensor(dims_, dtype_, device); - if (GetDevice()->Type() == DeviceType::kCPU) { - device->SetDevice(); - // CPU -> CUDA - CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), cudaMemcpyHostToDevice, - dynamic_cast(device)->Stream())); - } else { - // CUDA -> CUDA - // 1. CUDA -> CPU - // 2. 
CPU -> CUDA - Tensor cpu_tensor = To(DeviceManager::Instance()->GetDefaultDevice()); - device->SetDevice(); - CUDA_CHECK(cudaMemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), - cudaMemcpyHostToDevice, dynamic_cast(device)->Stream())); - } - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported device type: " << static_cast(device->Type()); + // H2D + core::DeviceGuard guard(device); + auto *impl = core::GetDeviceGuardImpl(device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(device)); + } else { + new_tensor = Tensor(dims_, dtype_, device); + // P2P + // 1. D2H + Tensor cpu_tensor = To(Device()); + // 2. H2D + core::DeviceGuard guard(buffer_device); + auto *impl = core::GetDeviceGuardImpl(buffer_device.type()); + impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(buffer_device)); } if (grad_) { @@ -235,9 +155,9 @@ Tensor Tensor::To(DataType dtype) { } auto device = GetDevice(); - device->SetDevice(); + core::DeviceGuard guard(device); - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "Cast"}); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "Cast"}); auto new_tensor = *kernel.Call>(shared_from_this(), dtype); if (grad_) { @@ -256,71 +176,33 @@ void Tensor::CopyFrom(const Tensor &src) { CHECK(Dims() == src.Dims()) << "Tensor::CopyFrom shape mismatch"; const size_t nbytes = SizeInBytes(); - const Device *dst_dev = GetDevice(); - const Device *src_dev = src.GetDevice(); - - switch (dst_dev->Type()) { - case DeviceType::kCPU: { - switch (src_dev->Type()) { - case DeviceType::kCPU: { - std::memcpy(DataPtr(), src.DataPtr(), nbytes); - break; - } -#ifdef USE_CUDA - case DeviceType::kCUDA: { - // CUDA -> CPU - CUDA_CHECK(cudaMemcpy(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToHost)); - break; - } -#endif - default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); - } - break; - } - -#ifdef USE_CUDA - case DeviceType::kCUDA: { - int current_device = -1; - CUDA_CHECK(cudaGetDevice(¤t_device)); - dst_dev->SetDevice(); - - const auto *dst_cuda = dynamic_cast(dst_dev); - switch (src_dev->Type()) { - case DeviceType::kCPU: { - // CPU -> CUDA - CUDA_CHECK(cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyHostToDevice, dst_cuda->Stream())); - break; - } - case DeviceType::kCUDA: { - const auto *src_cuda = dynamic_cast(src_dev); - if (src_cuda->Index() == dst_cuda->Index()) { - CUDA_CHECK( - cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToDevice, dst_cuda->Stream())); - } else { - int canAccessPeer = 0; - CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, dst_cuda->Index(), src_cuda->Index())); - if (canAccessPeer) { - CUDA_CHECK(cudaMemcpyPeerAsync(DataPtr(), dst_cuda->Index(), src.DataPtr(), src_cuda->Index(), - nbytes, dst_cuda->Stream())); - } else { - LOG(FATAL) << "Check accessibility between Device " << src_cuda->Index() << " and Device " - << dst_cuda->Index(); - } - } - break; - } - default: - LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); - } - - CUDA_CHECK(cudaSetDevice(current_device)); - break; - } -#endif - - default: - LOG(FATAL) << "Unsupported dst device type: " << static_cast(dst_dev->Type()); + const Device dst_dev = GetDevice(); + const Device src_dev = src.GetDevice(); + + if (dst_dev == src_dev) { + core::DeviceGuard guard(dst_dev); + auto *impl 
= core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev)); + } else if (dst_dev.type() == Device::DeviceType::kCPU) { + // D2H + core::DeviceGuard guard(src_dev); + auto *impl = core::GetDeviceGuardImpl(src_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev)); + } else if (src_dev.type() == Device::DeviceType::kCPU) { + // H2D + core::DeviceGuard guard(dst_dev); + auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + } else { + // TODO(dcj): maybe support p2p api later + // P2P + // 1. D2H + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(src); + // 2. H2D + core::DeviceGuard guard(dst_dev); + auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); + impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); } } @@ -376,7 +258,7 @@ std::shared_ptr Tensor::Or(const std::shared_ptr &other) { } std::shared_ptr Tensor::Add(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -385,12 +267,12 @@ std::shared_ptr Tensor::Add(float scalar) { } std::shared_ptr Tensor::Sub(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } std::shared_ptr Tensor::Mul(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -399,7 +281,7 @@ std::shared_ptr Tensor::Mul(float scalar) { } std::shared_ptr Tensor::Div(const std::shared_ptr &other) { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(other->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(other->GetDevice().type())); return std::make_shared()->Apply({shared_from_this(), other})[0]; } @@ -617,7 +499,7 @@ void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool gradient = std::make_shared(std::vector{}, dtype_, GetDevice()); gradient->Fill(1.0f); } else { - CHECK_EQ(static_cast(GetDevice()->Type()), static_cast(gradient->GetDevice()->Type())); + CHECK_EQ(static_cast(GetDevice().type()), static_cast(gradient->GetDevice().type())); CHECK_EQ(static_cast(dtype_), static_cast(gradient->Dtype())); CHECK_EQ(dims_.size(), gradient->Dims().size()); for (int idx = 0; idx < dims_.size(); ++idx) { CHECK_EQ(dims_[idx], gradient->Dims()[idx]); } @@ -755,23 +637,14 @@ void Tensor::SaveAsNpy(const std::string &path) const { const size_t num_bytes = num_elements * sizeof(float); // Prepare host buffer - std::vector host_buffer(num_elements); + auto impl = core::GetDeviceGuardImpl(GetDevice().type()); - if (GetDevice()->Type() == DeviceType::kCPU) { - // If on CPU, direct copy - std::memcpy(host_buffer.data(), DataPtr(), num_bytes); - } -#ifdef USE_CUDA - else if (GetDevice()->Type() == DeviceType::kCUDA) { - // If on CUDA, copy back to host - 
cudaDeviceSynchronize(); - cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); - CHECK_EQ(err, cudaSuccess) << "cudaMemcpy failed: " << cudaGetErrorString(err); - } -#endif - else { - LOG(FATAL) << "Unsupported device type for SaveAsNpy."; - } + impl->SynchronizeDevice(GetDevice()); + + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(*this); + + impl->SynchronizeDevice(GetDevice()); // Write .npy file std::ofstream file(path, std::ios::binary); @@ -813,7 +686,7 @@ void Tensor::SaveAsNpy(const std::string &path) const { file.write(header.c_str(), header.size()); // Write data - file.write(reinterpret_cast(host_buffer.data()), num_bytes); + file.write(reinterpret_cast(cpu_tensor.DataPtr()), num_bytes); file.close(); } @@ -876,21 +749,16 @@ void Tensor::Print(std::ostream &os) const { const size_t num_elements = NumElements(); const size_t num_bytes = num_elements * sizeof(float); - std::vector host_buffer(num_elements); + auto impl = core::GetDeviceGuardImpl(GetDevice().type()); - if (GetDevice()->Type() == DeviceType::kCPU) { - std::memcpy(host_buffer.data(), DataPtr(), num_bytes); - } -#ifdef USE_CUDA - else if (GetDevice()->Type() == DeviceType::kCUDA) { - cudaDeviceSynchronize(); - cudaError_t err = cudaMemcpy(host_buffer.data(), DataPtr(), num_bytes, cudaMemcpyDeviceToHost); - CHECK_EQ(err, cudaSuccess) << "cudaMemcpy failed: " << cudaGetErrorString(err); - } -#endif - else { - LOG(FATAL) << "Unsupported device type for Print."; - } + impl->SynchronizeDevice(GetDevice()); + + Tensor cpu_tensor(dims_, dtype_, Device()); + cpu_tensor.CopyFrom(*this); + + impl->SynchronizeDevice(GetDevice()); + + const float *buffer = static_cast(cpu_tensor.DataPtr()); const PrintOptions &opts = PrintOptions::Get(); const int64_t precision = opts.precision; @@ -901,7 +769,8 @@ void Tensor::Print(std::ostream &os) const { bool use_sci = opts.sci_mode.value_or(false); if (!opts.sci_mode.has_value()) { - for (float v : host_buffer) { + for (int idx = 0; idx < NumElements(); ++idx) { + const auto v = buffer[idx]; float abs_v = std::fabs(v); if ((abs_v > 0.0f && abs_v < 1e-4f) || abs_v >= 1e+4f) { use_sci = true; @@ -924,7 +793,7 @@ void Tensor::Print(std::ostream &os) const { std::vector str_vals(num_elements); size_t max_width = 0; for (size_t i = 0; i < num_elements; ++i) { - str_vals[i] = format_float(host_buffer[i]); + str_vals[i] = format_float(buffer[i]); max_width = std::max(max_width, str_vals[i].length()); }
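
The hunks above consistently replace `const Device *` handles obtained from DeviceManager with a small `Device` value type: a default-constructed `Device()` plays the role of the old default (CPU) device, and accelerator devices are built in place from a type tag and an ordinal. A minimal sketch of the selection logic used by the updated `Train()` functions, using only constructors and accessors that appear in this diff (the `infini_train` namespace qualification is assumed from the surrounding code):

#include "infini_train/include/device.h"

using infini_train::Device;

// Device selection as in the updated Train() functions: Device() is the CPU
// device (the old DeviceManager default); CUDA devices are constructed by
// value with an explicit ordinal instead of being looked up through
// DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ...).
Device SelectDevice(bool run_parallel, bool want_cpu, int thread_rank) {
    if (run_parallel) {
        return Device(Device::DeviceType::kCUDA, thread_rank);
    }
    return want_cpu ? Device() : Device(Device::DeviceType::kCUDA, 0);
}

Because the device is now passed by value, call sites such as AutocastGuard(device.type(), dtype), loss->To(Device()), and the Dispatcher lookups keyed on device.type() need neither null checks nor DeviceManager lookups.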
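
Every former `device->SetDevice()` call is replaced by constructing a `core::DeviceGuard` on the stack, and the manual cudaGetDevice / cudaSetDevice(current_device) bracketing that TensorBuffer and Tensor::To used to carry disappears, which suggests the guard restores the previously active device when it leaves scope. Below is a simplified stand-in that shows the RAII shape these call sites rely on; it is not the project's implementation, and the real core::DeviceGuard presumably dispatches through the per-backend impl returned by core::GetDeviceGuardImpl rather than calling CUDA directly:

#include <cuda_runtime.h>

// Simplified stand-in for core::DeviceGuard: make `device_index` current for
// the lifetime of the guard, then restore whatever device was active before.
class ScopedCudaDevice {
public:
    explicit ScopedCudaDevice(int device_index) {
        cudaGetDevice(&previous_);
        cudaSetDevice(device_index);
    }
    ~ScopedCudaDevice() { cudaSetDevice(previous_); }

    ScopedCudaDevice(const ScopedCudaDevice &) = delete;
    ScopedCudaDevice &operator=(const ScopedCudaDevice &) = delete;

private:
    int previous_ = 0;
};

// Call sites in the diff follow the same pattern:
//   core::DeviceGuard guard(device);   // was: device->SetDevice();
//   ... NCCL calls / kernel launches on `device` ...
//   // previous device restored at scope exit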
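
Where the old code read the compute stream directly from the device object (`device->Stream()`), the new code asks the device-guard impl for the stream and then calls `->cuda_stream()` on the result. The concrete type being cast to was lost from this copy of the diff (the angle-bracketed template and cast arguments were stripped), so the `core::CudaStream` name below is an assumption based on the `cuda_stream.h` include and the `cuda_stream()` accessor, as is the `infini_train::core` qualification; only the call chain itself is taken from the hunks above:

#include <cuda_runtime.h>

#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/device.h"
#include "infini_train/src/core/cuda/cuda_stream.h"

namespace core = infini_train::core;
using infini_train::Device;

// Resolve the cudaStream_t bound to `device`, as the NCCL process group,
// reducer, and profiler code above now do. The dynamic_cast target is assumed.
cudaStream_t ComputeStreamFor(Device device) {
    auto *impl = core::GetDeviceGuardImpl(device.type());
    auto *stream = dynamic_cast<core::CudaStream *>(impl->GetStream(device));
    return stream->cuda_stream();
}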
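
The rewritten TensorBuffer and Tensor copy paths stop branching on #ifdef USE_CUDA and instead pick a `core::MemcpyKind` (kH2D, kD2H, kD2D) and hand everything to the guard impl's MallocAsync / FreeAsync / MemcpyAsync, staging cross-device copies through a temporary CPU tensor. A condensed sketch of the same routing decision used by Tensor::CopyFrom above; only calls visible in the diff are used, and the surrounding buffer and stream management is omitted:

#include <cstddef>

#include "infini_train/include/core/device_guard.h"
#include "infini_train/include/device.h"

namespace core = infini_train::core;
using infini_train::Device;

// Route a raw copy between two buffers the way Tensor::CopyFrom now does:
// same device -> kD2D, to CPU -> kD2H on the source device, from CPU -> kH2D
// on the destination device; anything else is staged through host memory.
void CopyBytes(void *dst, Device dst_dev, const void *src, Device src_dev, size_t nbytes) {
    using DT = Device::DeviceType;
    if (dst_dev == src_dev) {
        core::DeviceGuard guard(dst_dev);
        auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
        impl->MemcpyAsync(dst, src, nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev));
    } else if (dst_dev.type() == DT::kCPU) {
        core::DeviceGuard guard(src_dev);                      // device -> host
        auto *impl = core::GetDeviceGuardImpl(src_dev.type());
        impl->MemcpyAsync(dst, src, nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev));
    } else if (src_dev.type() == DT::kCPU) {
        core::DeviceGuard guard(dst_dev);                      // host -> device
        auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
        impl->MemcpyAsync(dst, src, nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev));
    } else {
        // Cross-device copy: Tensor::CopyFrom stages this as D2H into a
        // temporary CPU tensor followed by H2D (no direct peer copy here).
    }
}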