11#include " infinicore/ops/embedding.hpp"
2- #include " ../../utils.hpp"
32#include " infinicore/context/context.hpp"
43#include < cstring>
5- #include < stdexcept>
64
75namespace infinicore ::op {
86
9- common::OpDispatcher<Embedding::schema> &Embedding::dispatcher () {
10- static common::OpDispatcher<Embedding::schema> dispatcher_;
11- return dispatcher_;
12- }
13-
14- void Embedding::execute (Tensor out, Tensor input, Tensor weight) {
15- // Check that all tensors are on the same device
16- // This is critical: if input is on CPU while out/weight are on GPU,
17- // passing CPU pointer to CUDA kernel will cause memory access errors
18- INFINICORE_ASSERT_TENSORS_SAME_DEVICE (out, input, weight);
19-
20- // Set device context
21- infinicore::context::setDevice (out->device ());
22-
23- // Use dispatcher to lookup kernel (infiniop implementation)
24- dispatcher ().lookup (out->device ().getType ())(out, input, weight);
25- }
26-
277Tensor embedding (Tensor input, // LongTensor of arbitrary shape containing the indices to extract
288 Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
299) {
3010 auto input_shape = input->shape ();
3111 auto weight_shape = weight->shape ();
12+ // auto vocab_size = weight_shape[0];
3213 auto embedding_dim = weight_shape[1 ];
3314
3415 // Assign memory to out variables
@@ -41,7 +22,68 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
4122}
4223
4324void embedding_ (Tensor out, Tensor input, Tensor weight) {
44- Embedding::execute (out, input, weight);
25+ assert (infinicore::DataType::I64 == input->dtype () || (infinicore::DataType::I32 == input->dtype ()));
26+ assert (infinicore::Device::Type::CPU == input->device ().getType ());
27+
28+ auto input_shape = input->shape ();
29+ auto weight_shape = weight->shape ();
30+ auto embedding_dim = weight_shape[1 ];
31+
32+ // Calculate the number of token
33+ Size counts = 1 ;
34+ for (auto &v : input_shape) {
35+ counts *= v;
36+ }
37+
38+ // the bytes of one token
39+ const Size bytes = dsize (weight->dtype ()) * embedding_dim;
40+ auto *weight_ptr = weight->data ();
41+ auto *out_ptr = out->data ();
42+
43+ // copies
44+ if (weight->device ().getType () == Device::Type::CPU) {
45+ if (infinicore::DataType::I64 == input->dtype ()) {
46+ const int64_t *input_arr = reinterpret_cast <const int64_t *>(input->data ());
47+ for (Size i = 0 ; i < counts; ++i) {
48+ int64_t idx = input_arr[i];
49+ assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
50+ std::memcpy (out_ptr + i * bytes,
51+ weight_ptr + idx * bytes,
52+ bytes);
53+ }
54+ } else if (infinicore::DataType::I32 == input->dtype ()) {
55+ const int32_t *input_arr = reinterpret_cast <const int32_t *>(input->data ());
56+
57+ for (Size i = 0 ; i < counts; ++i) {
58+ int32_t idx = input_arr[i];
59+ assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
60+ std::memcpy (out_ptr + i * bytes,
61+ weight_ptr + idx * bytes,
62+ bytes);
63+ }
64+ }
65+
66+ } else {
67+ if (infinicore::DataType::I64 == input->dtype ()) {
68+ const int64_t *input_arr = reinterpret_cast <const int64_t *>(input->data ());
69+ for (Size i = 0 ; i < counts; ++i) {
70+ int64_t idx = input_arr[i];
71+ assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
72+ context::memcpyD2D (out_ptr + i * bytes,
73+ weight_ptr + idx * bytes,
74+ bytes);
75+ }
76+ } else if (infinicore::DataType::I32 == input->dtype ()) {
77+ const int32_t *input_arr = reinterpret_cast <const int32_t *>(input->data ());
78+ for (Size i = 0 ; i < counts; ++i) {
79+ int32_t idx = input_arr[i];
80+ assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
81+ context::memcpyD2D (out_ptr + i * bytes,
82+ weight_ptr + idx * bytes,
83+ bytes);
84+ }
85+ }
86+ }
4587}
4688
4789} // namespace infinicore::op
0 commit comments