example/common/tokenizer.cc (7 changes: 5 additions & 2 deletions)

@@ -10,6 +10,9 @@
 #include "glog/logging.h"
 
 #include "example/common/utils.h"
+#include "infini_train/include/nn/functional.h"
+#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/tensor.h"
 
 namespace infini_train {
 
@@ -103,7 +106,7 @@ std::string Tokenizer::Decode(uint32_t token_id) const {
 }
 
 void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length,
-uint32_t text_length, const Device *device) const {
+uint32_t text_length, Device device) const {
 std::vector<int64_t> dims;
 dims.assign({batch_size, sequence_length});
 // x_tensor (FLAGS_batch_size, FLAGS_sequence_length) eq:(4, 64)
@@ -121,7 +124,7 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz
 uint64_t kRngState = kRngState;
 LOG(INFO) << "start generate text:";
 
-const auto *cpu_device = DeviceManager::Instance()->GetDefaultDevice();
+auto cpu_device = Device();
 for (int t = prompt_len; t < text_length; ++t) {
 x = std::make_shared<infini_train::Tensor>(x->To(device)); // CPU->calc device
 // TODO(jym): use no_grad forward later
example/common/tokenizer.h (10 changes: 4 additions & 6 deletions)

@@ -1,15 +1,13 @@
 #include <cctype>
 #include <cstdint>
-#include <memory>
-#include <vector>
 
 #include "infini_train/include/device.h"
-#include "infini_train/include/nn/functional.h"
-#include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/tensor.h"
 
 namespace infini_train {
 
+namespace nn {
+class Module;
+}
 class Tokenizer {
 public:
 enum class Version : uint32_t {
@@ -22,7 +20,7 @@ class Tokenizer {
 std::string Decode(uint32_t token_id) const;
 
 void GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length,
-uint32_t text_length, const Device *device) const;
+uint32_t text_length, Device device) const;
 
 uint32_t GetEndToken() const { return eot_token_; };
 
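The header change above swaps the `const Device *` parameter for a by-value `Device` and drops the heavy nn includes in favor of a forward declaration. As a reading aid, here is a minimal sketch of the value-type `Device` interface these call sites imply (default construction meaning the CPU device, an explicit type/index constructor, and a `type()` accessor). This sketch is an assumption inferred from the usage in this diff, not the actual contents of `infini_train/include/device.h`.

```cpp
// Hypothetical sketch of the value-type Device implied by the call sites in this PR.
// The real class lives in infini_train/include/device.h and may differ in detail.
#include <cstdint>

namespace infini_train {

class Device {
public:
    enum class DeviceType : uint8_t { kCPU = 0, kCUDA = 1, kCount = 2 };

    // A default-constructed Device denotes the CPU device.
    Device() : type_(DeviceType::kCPU), index_(0) {}
    Device(DeviceType type, int8_t index) : type_(type), index_(index) {}

    DeviceType type() const { return type_; }
    int8_t index() const { return index_; }

    bool operator==(const Device &other) const { return type_ == other.type_ && index_ == other.index_; }

private:
    DeviceType type_ = DeviceType::kCPU;
    int8_t index_ = 0;
};

} // namespace infini_train
```

Being a small trivially copyable value, such a `Device` can be passed and stored directly, which is what removes the `DeviceManager::Instance()->GetDefaultDevice()` lookups throughout the examples below.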
example/gpt2/main.cc (16 changes: 7 additions & 9 deletions)

@@ -104,7 +104,7 @@ void Train(const nn::parallel::Rank &rank) {
 using namespace nn::parallel;
 
 // select the device
-const Device *device;
+Device device;
 
 int ddp_world_size = global::GetDataParallelSize();
 int tp_world_size = global::GetTensorParallelSize();
@@ -125,7 +125,7 @@ void Train(const nn::parallel::Rank &rank) {
 const ProcessGroup *pp_pg = nullptr;
 
 if (rank.IsParallel()) {
-device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
+device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
 
 if (ddp_world_size > 1) {
 ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
@@ -149,8 +149,7 @@ void Train(const nn::parallel::Rank &rank) {
 nn::parallel::pp_rank = pp_rank;
 }
 } else {
-device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
-: DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
+device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0);
 }
 
 // calculate gradient accumulation from the desired total batch size and the current run configuration
@@ -201,7 +200,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 model = std::make_shared<nn::parallel::PipelineParallel>(
 model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::SGD>(optimizer),
-rank.thread_rank(), std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
+device, std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
 if (ddp_world_size > 1) {
 auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
 for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
@@ -285,7 +284,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
 // enable autocast for the current step
-infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
+infini_train::AutocastGuard autocast_guard(device.type(), dtype);
 
 // (bs, seq_len), (bs, seq_len)
 auto [x, y] = *train_iter;
@@ -308,7 +307,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
-auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
+auto loss_cpu = loss->To(Device());
 lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
 LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
 loss->Backward();
@@ -330,8 +329,7 @@ void Train(const nn::parallel::Rank &rank) {
 if (ddp_world_size > 1) {
 auto lossf_tensor = std::make_shared<Tensor>(&lossf, std::vector<int64_t>{}, DataType::kFLOAT32, device);
 function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg);
-lossf = static_cast<const float *>(
-lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0];
+lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
 }
 
 const auto iter_end = std::chrono::high_resolution_clock::now();
example/llama3/main.cc (16 changes: 7 additions & 9 deletions)

@@ -86,7 +86,7 @@ void Train(const nn::parallel::Rank &rank) {
 using namespace nn::parallel;
 
 // select the device
-const Device *device;
+Device device;
 
 int ddp_world_size = global::GetDataParallelSize();
 int tp_world_size = global::GetTensorParallelSize();
@@ -107,7 +107,7 @@ void Train(const nn::parallel::Rank &rank) {
 const ProcessGroup *pp_pg = nullptr;
 
 if (rank.IsParallel()) {
-device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
+device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
 
 if (ddp_world_size > 1) {
 ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
@@ -131,8 +131,7 @@ void Train(const nn::parallel::Rank &rank) {
 nn::parallel::pp_rank = pp_rank;
 }
 } else {
-device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
-: DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
+device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0);
 }
 
 // calculate gradient accumulation from the desired total batch size and the current run configuration
@@ -181,7 +180,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 model = std::make_shared<nn::parallel::PipelineParallel>(
 model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::Adam>(optimizer),
-rank.thread_rank(), std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
+device, std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
 if (ddp_world_size > 1) {
 auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
 for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
@@ -261,7 +260,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
 // enable autocast for the current step
-infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
+infini_train::AutocastGuard autocast_guard(device.type(), dtype);
 
 // (bs, seq_len), (bs, seq_len)
 auto [x, y] = *train_iter;
@@ -284,7 +283,7 @@ void Train(const nn::parallel::Rank &rank) {
 
 LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
-auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
+auto loss_cpu = loss->To(Device());
 lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
 LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
 loss->Backward();
@@ -306,8 +305,7 @@ void Train(const nn::parallel::Rank &rank) {
 if (ddp_world_size > 1) {
 auto lossf_tensor = std::make_shared<Tensor>(&lossf, std::vector<int64_t>{}, DataType::kFLOAT32, device);
 function::AllReduce(lossf_tensor, function::ReduceOpType::kAvg, ddp_pg);
-lossf = static_cast<const float *>(
-lossf_tensor->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0];
+lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
 }
 
 const auto iter_end = std::chrono::high_resolution_clock::now();
example/llama3/net.cc (5 changes: 2 additions & 3 deletions)

@@ -100,8 +100,7 @@ ApplyRotaryEmbedding(const std::shared_ptr<Tensor> &xq, const std::shared_ptr<Te
 }
 
 std::shared_ptr<Tensor> PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false,
-const infini_train::Device *device
-= DeviceManager::Instance()->GetDefaultDevice()) {
+infini_train::Device device = Device()) {
 DataType dtype = DataType::kFLOAT32;
 CHECK_GE(dim, 2) << "dim must be >= 2 for slicing";
 auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2);
@@ -127,7 +126,7 @@ std::vector<std::shared_ptr<Tensor>> SwiGLU::Forward(const std::vector<std::shar
 return {x[0] * nn::function::Sigmoid(x[0])};
 }
 
-RMSNorm::RMSNorm(int64_t dim, float eps, const infini_train::Device *device) : eps_(eps) {
+RMSNorm::RMSNorm(int64_t dim, float eps, infini_train::Device device) : eps_(eps) {
 parameters_[kParamWeightName]
 = std::make_shared<Tensor>(std::vector<int64_t>{dim}, DataType::kFLOAT32, device)->RequiresGrad();
 nn::init::Ones(parameters_[kParamWeightName]);
example/llama3/net.h (3 changes: 1 addition & 2 deletions)

@@ -51,8 +51,7 @@ class RMSNorm : public infini_train::nn::CloneableModule<RMSNorm> {
 public:
 static constexpr char kParamWeightName[] = "weight";
 
-explicit RMSNorm(int64_t dim, float eps = 1e-6f,
-const infini_train::Device *device = infini_train::DeviceManager::Instance()->GetDefaultDevice());
+explicit RMSNorm(int64_t dim, float eps = 1e-6f, infini_train::Device device = infini_train::Device());
 
 std::vector<std::shared_ptr<infini_train::Tensor>>
 Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
example/mnist/main.cc (5 changes: 2 additions & 3 deletions)

@@ -48,9 +48,8 @@ int main(int argc, char *argv[]) {
 DataLoader test_dataloader(test_dataset, FLAGS_bs);
 
 auto network = MNIST();
-const Device *device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)
-: DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
-const Device *cpu_device = DeviceManager::Instance()->GetDefaultDevice();
+Device device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0);
+Device cpu_device = Device();
 network.To(device);
 
 auto loss_fn = nn::CrossEntropyLoss();
infini_train/include/autocast.h (27 changes: 11 additions & 16 deletions)

@@ -3,15 +3,10 @@
 #include <string_view>
 #include <unordered_map>
 
-#include "common/common.h"
-#include "datatype.h"
-#include "device.h"
-#include "tensor.h"
-
-#ifdef USE_CUDA
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#endif
+#include "infini_train/include/common/common.h"
+#include "infini_train/include/datatype.h"
+#include "infini_train/include/device.h"
+#include "infini_train/include/tensor.h"
 
 namespace infini_train {
 namespace {
@@ -91,18 +86,18 @@ inline const std::unordered_map<std::string_view, CastPolicy> kOpCastPolicyMap =
 };
 
 // Default autocast data types for each device type
-inline constexpr std::array<DataType, static_cast<size_t>(DeviceType::kCount)> kDeviceDefaultDtype = {
+inline constexpr std::array<DataType, static_cast<size_t>(Device::DeviceType::kCount)> kDeviceDefaultDtype = {
 DataType::kBFLOAT16, // CPU
 DataType::kFLOAT16, // CUDA.
 };
 
 // Thread-local context to track autocast state
 struct AutocastContext {
-bool enabled = false; // Whether autocast is active in the current thread
-DeviceType device_type = DeviceType::kCPU; // Target device type (CPU/GPU)
-DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting
+bool enabled = false;                                       // Whether autocast is active in the current thread
+Device::DeviceType device_type = Device::DeviceType::kCPU;  // Target device type (CPU/GPU)
+DataType autocast_dtype = DataType::kBFLOAT16;              // The data type used for autocasting
 
-template <typename... ArgsT> void Autocast(std::pair<DeviceType, std::string> key, ArgsT &...args) {
+template <typename... ArgsT> void Autocast(std::pair<Device::DeviceType, std::string> key, ArgsT &...args) {
 if (!enabled) {
 return;
 }
@@ -172,14 +167,14 @@ inline thread_local AutocastContext tls_autocast_context;
 // RAII guard to enable/disable autocast in a scope
 class AutocastGuard {
 public:
-AutocastGuard(DeviceType device_type, DataType autocast_dtype) {
+AutocastGuard(Device::DeviceType device_type, DataType autocast_dtype) {
 saved_context_ = tls_autocast_context;
 tls_autocast_context.enabled = true;
 tls_autocast_context.device_type = device_type;
 tls_autocast_context.autocast_dtype = autocast_dtype;
 }
 
-AutocastGuard(DeviceType device_type)
+AutocastGuard(Device::DeviceType device_type)
 : AutocastGuard(device_type, kDeviceDefaultDtype[static_cast<size_t>(device_type)]) {}
 
 // Disable autocast (restore previous state)
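As a usage sketch of the reworked guard (assumed from the constructors above and the call sites in the example programs, not taken verbatim from this PR): the two-argument form pins an explicit dtype, while the single-argument form falls back to the per-device entry in `kDeviceDefaultDtype`.

```cpp
// Sketch only: driving the RAII AutocastGuard from a training step.
// `device` is the value-type Device selected earlier; the function name is illustrative.
#include "infini_train/include/autocast.h"
#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

void StepWithAutocast(const infini_train::Device &device) {
    // Explicit dtype, mirroring example/gpt2/main.cc and example/llama3/main.cc.
    infini_train::AutocastGuard guard(device.type(), infini_train::DataType::kBFLOAT16);
    // ... forward pass: eligible ops are cast according to kOpCastPolicyMap ...
}   // guard leaves scope and the previous thread-local autocast state is restored
```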
infini_train/include/autograd/comm.h (26 changes: 13 additions & 13 deletions)

@@ -4,6 +4,7 @@
 #include <vector>
 
 #include "infini_train/include/autograd/function.h"
+#include "infini_train/include/device.h"
 
 namespace infini_train {
 class Tensor;
@@ -19,7 +20,7 @@ class Scatter : public autograd::Function {
 public:
 static constexpr char kType[] = "ScatterFunction";
 
-explicit Scatter(const std::vector<const Device *> &target_gpus, int64_t dim,
+explicit Scatter(const std::vector<Device> &target_gpus, int64_t dim,
 const infini_train::nn::parallel::ProcessGroup *pg = nullptr);
 
 std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
@@ -31,17 +32,16 @@ class Scatter : public autograd::Function {
 
 private:
 const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
-std::vector<const Device *> target_gpus_;
-const Device *input_device_ = nullptr;
+std::vector<Device> target_gpus_;
+Device input_device_ = Device();
 int64_t dim_ = 0;
 };
 
 class Gather : public autograd::Function {
 public:
 static constexpr char kType[] = "GatherFunction";
 
-explicit Gather(const Device *target_device, int64_t dim,
-const infini_train::nn::parallel::ProcessGroup *pg = nullptr);
+explicit Gather(Device target_device, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg = nullptr);
 
 std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
 
@@ -52,8 +52,8 @@ class Gather : public autograd::Function {
 
 private:
 const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
-const Device *target_device_ = nullptr;
-std::vector<const Device *> input_gpus_;
+Device target_device_ = Device();
+std::vector<Device> input_gpus_;
 int64_t dim_ = 0;
 bool unsqueezed_scalar_ = false;
 };
@@ -62,7 +62,7 @@ class Broadcast : public autograd::Function {
 public:
 static constexpr char kType[] = "BroadcastFunction";
 
-explicit Broadcast(const std::vector<const Device *> &target_gpus,
+explicit Broadcast(const std::vector<Device> &target_gpus,
 const infini_train::nn::parallel::ProcessGroup *pg = nullptr);
 
 std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
@@ -74,16 +74,16 @@ class Broadcast : public autograd::Function {
 
 private:
 const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
-std::vector<const Device *> target_gpus_;
+std::vector<Device> target_gpus_;
 int64_t num_inputs_ = 0;
-const Device *input_device_ = nullptr;
+Device input_device_ = Device();
 };
 
 class ReduceAddCoalesced : public autograd::Function {
 public:
 static constexpr char kType[] = "ReduceAddCoalescedFunction";
 
-explicit ReduceAddCoalesced(const Device *destination, int64_t num_inputs,
+explicit ReduceAddCoalesced(Device destination, int64_t num_inputs,
 const infini_train::nn::parallel::ProcessGroup *pg = nullptr);
 
 std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
@@ -95,8 +95,8 @@ class ReduceAddCoalesced : public autograd::Function {
 
 private:
 const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
-const Device *destination_ = nullptr;
-std::vector<const Device *> target_gpus_;
+Device destination_ = Device();
+std::vector<Device> target_gpus_;
 int64_t num_inputs_ = 0;
 };
 } // namespace infini_train::autograd
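The autograd communication ops now hold `Device` values instead of `const Device *` pointers, and their members default-construct to the CPU device rather than to `nullptr`. A hedged usage sketch based on the `Scatter` declaration above (the free function, tensor name, and the choice to call `Forward` directly are illustrative, not taken from this PR):

```cpp
// Sketch only: scattering one batch across two CUDA devices with the value-type API.
#include <memory>
#include <vector>

#include "infini_train/include/autograd/comm.h"
#include "infini_train/include/device.h"
#include "infini_train/include/tensor.h"

using infini_train::Device;

std::vector<std::shared_ptr<infini_train::Tensor>>
ScatterBatch(const std::shared_ptr<infini_train::Tensor> &batch) {
    std::vector<Device> target_gpus = {Device(Device::DeviceType::kCUDA, 0),
                                       Device(Device::DeviceType::kCUDA, 1)};
    // dim 0: split along the batch dimension; default (null) process group.
    infini_train::autograd::Scatter scatter(target_gpus, /*dim=*/0);
    return scatter.Forward({batch});  // one shard per entry in target_gpus
}
```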
infini_train/include/common/common.h (3 changes: 3 additions & 0 deletions)

@@ -1,5 +1,8 @@
 #pragma once
 
+#include <cstdint>
+#include <vector>
+
 #include "glog/logging.h"
 
 #include "infini_train/include/datatype.h"
infini_train/include/core/blas_handle.h (11 changes: 11 additions & 0 deletions)

@@ -0,0 +1,11 @@
+#pragma once
+
+namespace infini_train::core {
+
+class BlasHandle {
+public:
+    BlasHandle(){};
+    virtual ~BlasHandle() = default;
+};
+
+} // namespace infini_train::core
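The new header adds only an empty `BlasHandle` base class with a virtual destructor, which suggests backend-specific handles are meant to derive from it in follow-up changes. A purely hypothetical CUDA-side sketch (not part of this PR, and the class name is invented):

```cpp
// Hypothetical only: a cuBLAS-backed handle deriving from the new base class.
// This PR adds just the empty BlasHandle; the wrapper below is an assumption.
#include <cublas_v2.h>

#include "infini_train/include/core/blas_handle.h"

namespace infini_train::core {

class CublasHandle : public BlasHandle {
public:
    CublasHandle() { cublasCreate(&handle_); }
    ~CublasHandle() override { cublasDestroy(handle_); }

    cublasHandle_t handle() const { return handle_; }

private:
    cublasHandle_t handle_ = nullptr;
};

} // namespace infini_train::core
```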