@@ -244,6 +244,23 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, cudaStream_t stream);

/*! \brief Check overflow and scale a list of tensors, with the scaling factor provided as a tensor.
 *
 * \warning This API is **experimental** and subject to change.
 *
 * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
 * \param[in,out] tensor_lists 2D array of tensor pointers: input tensors in list 0, output tensors in list 1.
 * \param[in] num_tensor_lists Size (dim0) of tensor_lists.
 * \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
 * \param[in] scale Single-element tensor holding the scaling factor.
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor scale,
cudaStream_t stream);

/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
transformer_engine/common/multi_tensor/scale.cu (95 additions, 0 deletions)
@@ -102,6 +102,75 @@ struct ScaleFunctor {
}
};

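// Variant of ScaleFunctor that reads the scale factor from device memory
// (scale_ptr) inside the kernel, so the scalar can stay resident on the GPU
// instead of being passed by value from the host.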
template <typename in_t, typename out_t>
struct ScalePtrFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<2> &tl, // NOLINT(*)
float *scale_ptr) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
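// Every thread loads the same scalar from global memory.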
float scale = *scale_ptr;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
in += chunk_idx * chunk_size;

out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
out += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];

// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite) *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
}
};

void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale,
cudaStream_t stream) {
@@ -114,6 +183,18 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
NVTE_CHECK_CUDA(cudaGetLastError());
}

void multi_tensor_scale_tensor_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float *scale,
cudaStream_t stream) {
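// Nested dtype dispatch: instantiate ScalePtrFunctor for the (input, output)
// dtype pair, taken from the first tensor of each list.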
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScalePtrFunctor<p_in_type, g_in_type>(), stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace multi_tensor_scale
} // namespace transformer_engine

@@ -127,3 +208,17 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
}

void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor scale,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda);
using namespace transformer_engine;

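// The scale tensor's data pointer is reinterpreted as float *, so it is
// expected to hold a single FP32 value resident on the device.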
Tensor *scale_tensor = convertNVTETensorCheck(scale);
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
}
transformer_engine/pytorch/csrc/extensions.h (4 additions, 0 deletions)
@@ -411,6 +411,10 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale);

void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor scale);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp (13 additions, 0 deletions)
@@ -18,4 +18,17 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
num_tensors, scale, at::cuda::getCurrentCUDAStream());
}

void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor scale) {
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto scale_cu = makeTransformerEngineTensor(scale);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, scale_cu.data(),
at::cuda::getCurrentCUDAStream());
}

} // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/pybind.cpp (3 additions, 0 deletions)
@@ -402,6 +402,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_scale_tensor", &transformer_engine::pytorch::multi_tensor_scale_tensor_cuda,
"Fused overflow check + scale for a list of contiguous tensors with scale passed as tensor",
py::call_guard<py::gil_scoped_release>());
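A minimal sketch of driving the new binding from Python, mirroring how the existing multi_tensor_scale binding is called; the chunk size (apex's usual 2048 * 32), tensor shapes, and the 0.5 factor below are illustrative assumptions, not part of this PR:

    import torch
    import transformer_engine_torch as tex

    noop_flag = torch.zeros(1, dtype=torch.int32, device="cuda")  # set to 1 on inf/nan
    srcs = [torch.randn(1024, device="cuda") for _ in range(4)]
    dsts = [torch.empty_like(t) for t in srcs]
    scale = torch.full((1,), 0.5, dtype=torch.float32, device="cuda")  # stays on the GPU

    # tensor_lists is [inputs, outputs]; each src is scaled into the matching dst.
    tex.multi_tensor_scale_tensor(2048 * 32, noop_flag, [srcs, dsts], scale)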
m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
transformer_engine/pytorch/optimizers/__init__.py (1 addition, 0 deletions)
@@ -5,6 +5,7 @@
"""Fused optimizers and multi-tensor kernels."""
from transformer_engine_torch import (
multi_tensor_scale,
multi_tensor_scale_tensor,
multi_tensor_l2norm,
multi_tensor_unscale_l2norm,
multi_tensor_adam,