diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h
index 1bea4cb21f..73616ce88e 100644
--- a/transformer_engine/common/include/transformer_engine/multi_tensor.h
+++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h
@@ -244,6 +244,23 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                   float scale, cudaStream_t stream);
 
+/*! \brief Check overflow and scale a list of tensors, with the scale provided as a tensor.
+ *
+ *  \warning   This API is **experimental** and subject to change.
+ *
+ *  \param[in]     chunk_size            Number of tensor elements processed by a CUDA block.
+ *  \param[in]     noop_flag             If this single element tensor has non-zero value, kernel will exit immediately.
+ *  \param[in,out] tensor_lists          2D array of input tensors.
+ *  \param[in]     num_tensor_lists      Size (dim0) of tensor_lists.
+ *  \param[in]     num_tensors_per_list  Size (dim1) of tensor_lists.
+ *  \param[in]     scale                 Scaling factor, provided as a tensor.
+ *  \param[in]     stream                CUDA stream used for this operation.
+ */
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
+                                         NVTETensor **tensor_lists, const size_t num_tensor_lists,
+                                         const size_t num_tensors_per_list, NVTETensor scale,
+                                         cudaStream_t stream);
+
 /*! \brief Check overflow and scale a list of tensors.
  *
  *  \warning   This API is **experimental** and subject to change.
diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu
index b3266200c4..c68b935ce5 100644
--- a/transformer_engine/common/multi_tensor/scale.cu
+++ b/transformer_engine/common/multi_tensor/scale.cu
@@ -102,6 +102,75 @@ struct ScaleFunctor {
   }
 };
 
+template <typename in_t, typename out_t>
+struct ScalePtrFunctor {
+  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
+                                             TensorListMetadata<2> &tl,  // NOLINT(*)
+                                             float *scale_ptr) {
+    // I'd like this kernel to propagate infs/nans.
+    // if(*noop_gmem == 1)
+    //   return;
+    float scale = *scale_ptr;
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
+    in += chunk_idx * chunk_size;
+
+    out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
+    out += chunk_idx * chunk_size;
+
+    n -= chunk_idx * chunk_size;
+
+    bool finite = true;
+    in_t r_in[ILP];
+    out_t r_out[ILP];
+
+    // to make things simple, we put the aligned case in a different code path
+    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
+      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_in, in, 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
+        }
+        // store
+        load_store(out, r_out, i_start, 0);
+      }
+    } else {
+      // Non-divergent exit condition for __syncthreads, not necessary here
+      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_in[ii] = 0.f;
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) r_in[ii] = in[i];
+        }
+        // note for clarification to future michael:
+        // From a pure memory dependency perspective, there's likely no point unrolling
+        // the write loop, since writes just fire off once their LDGs arrive.
+        // Put another way, the STGs are dependent on the LDGs, but not on each other.
+        // There is still compute ILP benefit from unrolling the loop though.
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
+        }
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) out[i] = r_out[ii];
+        }
+      }
+    }
+    if (!finite) *noop_gmem = 1;  // Blindly fire off a write. These will race but that's ok.
+  }
+};
+
 void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
                              std::vector<std::vector<Tensor *>> tensor_lists, float scale,
                              cudaStream_t stream) {
@@ -114,6 +183,18 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, Tensor noop_flag,
+                                    std::vector<std::vector<Tensor *>> tensor_lists, float *scale,
+                                    cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), p_in_type,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          tensor_lists[1][0]->dtype(), g_in_type,
+          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                ScalePtrFunctor<p_in_type, g_in_type>(), stream, scale);))
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 }  // namespace multi_tensor_scale
 
 }  // namespace transformer_engine
@@ -127,3 +208,17 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
 }
+
+void nvte_multi_tensor_scale_tensor_cuda(int chunk_size, NVTETensor noop_flag,
+                                         NVTETensor **tensor_lists, const size_t num_tensor_lists,
+                                         const size_t num_tensors_per_list, NVTETensor scale,
+                                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda);
+  using namespace transformer_engine;
+
+  Tensor *scale_tensor = convertNVTETensorCheck(scale);
+  multi_tensor_scale::multi_tensor_scale_tensor_cuda(
+      chunk_size, *convertNVTETensorCheck(noop_flag),
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
+      reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
+}
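Editor's note (not part of the patch): ScalePtrFunctor applies out = in * scale with the scale read from device memory via scale_ptr, and raises the noop flag when any input element is non-finite while still writing the scaled output. A rough eager-mode PyTorch sketch of those semantics, with illustrative names only, follows.

# Reference semantics only; the fused kernel does this per 2^N-element chunk on-device.
import torch

def reference_scale_tensor(ins, outs, scale_t, noop_flag):
    """out[i] = in[i] * scale, with scale read from a one-element tensor;
    noop_flag is set to 1 if any input element is non-finite."""
    for src, dst in zip(ins, outs):
        src_f = src.float()
        dst.copy_(src_f * scale_t)          # scale stays a tensor; no host read
        if not torch.isfinite(src_f).all():  # mirrors the "blind" noop write
            noop_flag.fill_(1)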
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 52ef02a347..656ec299ca 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -411,6 +411,10 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
 void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                              std::vector<std::vector<at::Tensor>> tensor_lists, float scale);
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
+                                    std::vector<std::vector<at::Tensor>> tensor_lists,
+                                    at::Tensor scale);
+
 std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     at::optional<bool> per_tensor_python);
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
index 4bb83bfeed..de957a901a 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
@@ -18,4 +18,16 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                               num_tensors, scale, at::cuda::getCurrentCUDAStream());
 }
 
+void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor noop_flag,
+                                    std::vector<std::vector<at::Tensor>> tensor_lists,
+                                    at::Tensor scale) {
+  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
+  auto scale_cu = makeTransformerEngineTensor(scale);
+  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
+      makeTransformerEngineTensorList(tensor_lists);
+  nvte_multi_tensor_scale_tensor_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
+                                      num_lists, num_tensors, scale_cu.data(),
+                                      at::cuda::getCurrentCUDAStream());
+}
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index e73eca7861..9fb91ecd09 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -402,6 +402,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_scale_tensor", &transformer_engine::pytorch::multi_tensor_scale_tensor_cuda,
+        "Fused overflow check + scale for a list of contiguous tensors with the scale passed as a tensor",
+        py::call_guard<py::gil_scoped_release>());
   m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
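Editor's note (not part of the patch): the pybind registration above exposes the new overload to Python with the same argument order as multi_tensor_scale, except the last argument is a tensor. A usage sketch under the usual multi-tensor-apply conventions (int32 one-element noop flag, fp32 one-element scale on the same device, chunk size of 65536 as used elsewhere; these values are assumptions, not taken from the diff):

import torch
import transformer_engine_torch as tex

params = [torch.randn(1024, device="cuda") for _ in range(4)]
scaled = [torch.empty_like(p) for p in params]
noop_flag = torch.zeros(1, dtype=torch.int32, device="cuda")
scale = torch.full((1,), 0.5, dtype=torch.float32, device="cuda")

# tensor_lists is [inputs, outputs]; the scale value never leaves the GPU.
tex.multi_tensor_scale_tensor(65536, noop_flag, [params, scaled], scale)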
diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py
index 792eab094a..7220f1924a 100644
--- a/transformer_engine/pytorch/optimizers/__init__.py
+++ b/transformer_engine/pytorch/optimizers/__init__.py
@@ -5,6 +5,7 @@
 """Fused optimizers and multi-tensor kernels."""
 from transformer_engine_torch import (
     multi_tensor_scale,
+    multi_tensor_scale_tensor,
     multi_tensor_l2norm,
     multi_tensor_unscale_l2norm,
     multi_tensor_adam,
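Editor's note (not part of the patch): with the re-export above, the kernel is also reachable from the optimizers package. The apparent benefit of the tensor overload is that a device-resident scale (e.g. one produced by a grad scaler) can be applied directly, whereas the float overload would require an .item() call and the device-to-host sync that comes with it. Continuing the sketch from the previous note:

from transformer_engine.pytorch.optimizers import multi_tensor_scale_tensor

# Same call as before, just through the public optimizers namespace.
multi_tensor_scale_tensor(65536, noop_flag, [params, scaled], scale)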