Skip to content

Commit 97befe6

Browse files
committed
issue/889 - revert embedding modifications
1 parent 6f8a443 commit 97befe6

7 files changed

Lines changed: 63 additions & 193 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "ops/add_rms_norm.hpp"
55
#include "ops/attention.hpp"
66
#include "ops/causal_softmax.hpp"
7-
#include "ops/embedding.hpp"
87
#include "ops/flash_attention.hpp"
98
#include "ops/kv_caching.hpp"
109
#include "ops/matmul.hpp"

include/infinicore/ops/embedding.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44

55
namespace infinicore::op {
66

7-
class Embedding {
8-
public:
9-
using schema = void (*)(Tensor, Tensor, Tensor);
10-
static void execute(Tensor out, Tensor input, Tensor weight);
11-
static common::OpDispatcher<schema> &dispatcher();
12-
};
13-
147
Tensor embedding(Tensor input, Tensor weight);
158
void embedding_(Tensor out, Tensor input, Tensor weight);
169
} // namespace infinicore::op

include/infiniop.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "infiniop/ops/clip.h"
1010
#include "infiniop/ops/conv.h"
1111
#include "infiniop/ops/dequantize_awq.h"
12-
#include "infiniop/ops/embedding.h"
1312
#include "infiniop/ops/flash_attention.h"
1413
#include "infiniop/ops/gelu.h"
1514
#include "infiniop/ops/gemm.h"

include/infiniop/ops/embedding.h

Lines changed: 0 additions & 25 deletions
This file was deleted.
Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,15 @@
11
#include "infinicore/ops/embedding.hpp"
2-
#include "../../utils.hpp"
32
#include "infinicore/context/context.hpp"
43
#include <cstring>
5-
#include <stdexcept>
64

75
namespace infinicore::op {
86

9-
common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
10-
static common::OpDispatcher<Embedding::schema> dispatcher_;
11-
return dispatcher_;
12-
}
13-
14-
void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
15-
// Check that all tensors are on the same device
16-
// This is critical: if input is on CPU while out/weight are on GPU,
17-
// passing CPU pointer to CUDA kernel will cause memory access errors
18-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
19-
20-
// Set device context
21-
infinicore::context::setDevice(out->device());
22-
23-
// Use dispatcher to lookup kernel (infiniop implementation)
24-
dispatcher().lookup(out->device().getType())(out, input, weight);
25-
}
26-
277
Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
288
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
299
) {
3010
auto input_shape = input->shape();
3111
auto weight_shape = weight->shape();
12+
// auto vocab_size = weight_shape[0];
3213
auto embedding_dim = weight_shape[1];
3314

3415
// Assign memory to out variables
@@ -41,7 +22,68 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
4122
}
4223

4324
void embedding_(Tensor out, Tensor input, Tensor weight) {
44-
Embedding::execute(out, input, weight);
25+
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
26+
assert(infinicore::Device::Type::CPU == input->device().getType());
27+
28+
auto input_shape = input->shape();
29+
auto weight_shape = weight->shape();
30+
auto embedding_dim = weight_shape[1];
31+
32+
// Calculate the number of token
33+
Size counts = 1;
34+
for (auto &v : input_shape) {
35+
counts *= v;
36+
}
37+
38+
// the bytes of one token
39+
const Size bytes = dsize(weight->dtype()) * embedding_dim;
40+
auto *weight_ptr = weight->data();
41+
auto *out_ptr = out->data();
42+
43+
// copies
44+
if (weight->device().getType() == Device::Type::CPU) {
45+
if (infinicore::DataType::I64 == input->dtype()) {
46+
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
47+
for (Size i = 0; i < counts; ++i) {
48+
int64_t idx = input_arr[i];
49+
assert((idx >= 0) && (idx < weight_shape[0]));
50+
std::memcpy(out_ptr + i * bytes,
51+
weight_ptr + idx * bytes,
52+
bytes);
53+
}
54+
} else if (infinicore::DataType::I32 == input->dtype()) {
55+
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
56+
57+
for (Size i = 0; i < counts; ++i) {
58+
int32_t idx = input_arr[i];
59+
assert((idx >= 0) && (idx < weight_shape[0]));
60+
std::memcpy(out_ptr + i * bytes,
61+
weight_ptr + idx * bytes,
62+
bytes);
63+
}
64+
}
65+
66+
} else {
67+
if (infinicore::DataType::I64 == input->dtype()) {
68+
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
69+
for (Size i = 0; i < counts; ++i) {
70+
int64_t idx = input_arr[i];
71+
assert((idx >= 0) && (idx < weight_shape[0]));
72+
context::memcpyD2D(out_ptr + i * bytes,
73+
weight_ptr + idx * bytes,
74+
bytes);
75+
}
76+
} else if (infinicore::DataType::I32 == input->dtype()) {
77+
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
78+
for (Size i = 0; i < counts; ++i) {
79+
int32_t idx = input_arr[i];
80+
assert((idx >= 0) && (idx < weight_shape[0]));
81+
context::memcpyD2D(out_ptr + i * bytes,
82+
weight_ptr + idx * bytes,
83+
bytes);
84+
}
85+
}
86+
}
4587
}
4688

4789
} // namespace infinicore::op

src/infinicore/ops/embedding/embedding_infiniop.cc

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/infiniop/ops/embedding/operator.cc

Lines changed: 0 additions & 89 deletions
This file was deleted.

0 commit comments

Comments (0)