Skip to content

Commit 20401dc

Browse files
committed
feat: moved operator generation to parse
1 parent 160da40 commit 20401dc

5 files changed

Lines changed: 126 additions & 59 deletions

File tree

docs_sphinx/chapters/einsum_trees.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ This structure holds one node of the tree, its possible children, dimension size
4242
struct EinsumNode
4343
{
4444
NodeType type;
45-
float *tensor;
45+
int32_t input_tensor_index = -1;
46+
float *tensor = nullptr;
47+
mini_jit::TensorOperation tensor_op;
4648
4749
// Always filled — dims of the output tensor
4850
std::vector<int64_t> output_dim_ids;
@@ -75,12 +77,12 @@ This structure holds one node of the tree, its possible children, dimension size
7577
std::string _to_string(uint depth, std::string connection, std::string depthString) const;
7678
};
7779
78-
Then, we implemented the logic to parse the string into a set of nodes in the ``parse_tree_no_optimization()`` method. This method also indicates whether
80+
Then, we implemented the logic to parse the string into a set of nodes in the ``parse_tree_no_optimization(bool)`` method. This method also indicates whether
7981
the parsing was successful, ``ErrorParse``.
8082

8183
.. code-block:: cpp
8284
83-
ErrorParse parse_tree_no_optimization();
85+
ErrorParse parse_tree_no_optimization(bool build_operators);
8486
8587
// AND
8688

docs_sphinx/submissions/report_25_06_12.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ This structure holds one node of the tree, its possible children, dimension size
8787
std::string _to_string(uint depth, std::string connection, std::string depthString) const;
8888
};
8989
90-
Then, we implemented the logic to parse the string into a set of nodes in the ``parse_tree_no_optimization()`` method. This method also indicates whether
90+
Then, we implemented the logic to parse the string into a set of nodes in the ``parse_tree_no_optimization(bool)`` method. This method also indicates whether
9191
the parsing was successful, ``ErrorParse``.
9292

9393
.. code-block:: cpp
9494
95-
ErrorParse parse_tree_no_optimization();
95+
ErrorParse parse_tree_no_optimization(bool build_operators);
9696
9797
// AND
9898

src/main/EinsumTree.cpp

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mini_jit::EinsumTree::EinsumTree(const std::string &tree_str, const std::vector<
1414
{
1515
}
1616

17-
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimization()
17+
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimization(bool build_operators)
1818
{
1919
if (root != nullptr)
2020
{
@@ -27,6 +27,16 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimizatio
2727
tensorIndex = 0;
2828
assign_tensor_indices(root);
2929

30+
if (error_parse != ErrorParse::None)
31+
{
32+
return error_parse;
33+
}
34+
35+
if (build_operators)
36+
{
37+
error_parse = generate_operators();
38+
}
39+
3040
return error_parse;
3141
}
3242

@@ -496,20 +506,7 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
496506
node->tensor = new float[node->get_size(dim_sizes)]();
497507
}
498508

499-
mini_jit::TensorOperation tensor_op;
500-
TensorConfig config = lower_node(node);
501-
TensorOperation::error_t error_setup = tensor_op.setup(config);
502-
error = parse_setup_error(error_setup);
503-
504-
if (error != ErrorExecute::None)
505-
{
506-
return error;
507-
}
508-
509-
#ifdef SAVE_JITS_TO_FILE
510-
tensor_op.write_kernel_to_file(node->name());
511-
#endif // SAVE_JITS_TO_FILE
512-
tensor_op.execute(node->left->tensor, nullptr, node->tensor);
509+
node->tensor_op.execute(node->left->tensor, nullptr, node->tensor);
513510
}
514511
else if (node->type == NodeType::Contraction)
515512
{
@@ -539,20 +536,7 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
539536
node->tensor = new float[node->get_size(dim_sizes)]();
540537
}
541538

542-
mini_jit::TensorOperation tensor_op;
543-
TensorConfig config = lower_node(node);
544-
TensorOperation::error_t error_setup = tensor_op.setup(config);
545-
error = parse_setup_error(error_setup);
546-
547-
if (error != ErrorExecute::None)
548-
{
549-
return error;
550-
}
551-
552-
#ifdef SAVE_JITS_TO_FILE
553-
tensor_op.write_kernel_to_file(node->name());
554-
#endif // SAVE_JITS_TO_FILE
555-
tensor_op.execute(node->left->tensor, node->right->tensor, node->tensor);
539+
node->tensor_op.execute(node->left->tensor, node->right->tensor, node->tensor);
556540
}
557541
else
558542
{
@@ -598,19 +582,19 @@ std::vector<int64_t> mini_jit::EinsumTree::get_output_dims(const std::vector<int
598582
return dims;
599583
}
600584

601-
mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::parse_setup_error(TensorOperation::error_t error)
585+
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_setup_error(TensorOperation::error_t error)
602586
{
603587
if (error == TensorOperation::error_t::success)
604588
{
605-
return ErrorExecute::None;
589+
return ErrorParse::None;
606590
}
607591

608592
uint32_t error_num = static_cast<uint32_t>(error) + 100;
609593

610594
release_assert(error_num >= 101, "Expected error_num to be larger equal than 101.");
611595
release_assert(error_num <= 115, "Expected error_num to be less equal than 115.");
612596

613-
return static_cast<ErrorExecute>(error_num);
597+
return static_cast<ErrorParse>(error_num);
614598
}
615599

616600
bool mini_jit::EinsumTree::is_unit_stride_n(EinsumNode *node)
@@ -673,7 +657,7 @@ void mini_jit::EinsumTree::conditional_swap(mini_jit::EinsumTree::EinsumNode *no
673657

674658
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
675659
{
676-
ErrorParse error = parse_tree_no_optimization();
660+
ErrorParse error = parse_tree_no_optimization(false);
677661

678662
if (error != ErrorParse::None)
679663
{
@@ -682,6 +666,77 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
682666

683667
optimize(root);
684668

669+
error = generate_operators();
670+
671+
return error;
672+
}
673+
674+
mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operators()
675+
{
676+
if (root == nullptr)
677+
{
678+
std::cerr << "EinsumTree: Cannot execute, root is null." << std::endl;
679+
return ErrorParse::InvalidRoot;
680+
}
681+
682+
ErrorParse error = ErrorParse::None;
683+
std::vector<EinsumNode *> stack = {root};
684+
685+
while (stack.size() > 0)
686+
{
687+
EinsumNode *node = stack.back();
688+
stack.pop_back();
689+
690+
if (node->type == NodeType::Leaf)
691+
{
692+
continue;
693+
}
694+
else if (node->type == NodeType::Transposition)
695+
{
696+
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
697+
release_assert(node->right == nullptr, "Expected the right child of contraction to be a nullptr.");
698+
699+
stack.push_back(node->left);
700+
701+
TensorConfig config = lower_node(node);
702+
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
703+
error = parse_setup_error(error_setup);
704+
705+
if (error != ErrorParse::None)
706+
{
707+
return error;
708+
}
709+
710+
#ifdef SAVE_JITS_TO_FILE
711+
node->tensor_op.write_kernel_to_file(node->name());
712+
#endif // SAVE_JITS_TO_FILE
713+
}
714+
else if (node->type == NodeType::Contraction)
715+
{
716+
release_assert(node->left != nullptr, "Expected the left child of contraction to be a valid pointer.");
717+
release_assert(node->right != nullptr, "Expected the right child of contraction to be a valid pointer.");
718+
719+
stack.push_back(node->left);
720+
stack.push_back(node->right);
721+
722+
TensorConfig config = lower_node(node);
723+
TensorOperation::error_t error_setup = node->tensor_op.setup(config);
724+
error = parse_setup_error(error_setup);
725+
726+
if (error != ErrorParse::None)
727+
{
728+
return error;
729+
}
730+
731+
#ifdef SAVE_JITS_TO_FILE
732+
node->tensor_op.write_kernel_to_file(node->name());
733+
#endif // SAVE_JITS_TO_FILE
734+
}
735+
else
736+
{
737+
release_assert(false, "Found unhandled einsum tree node type.");
738+
}
739+
}
685740
return ErrorParse::None;
686741
}
687742

src/main/EinsumTree.h

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,7 @@ namespace mini_jit
2424
ExpectedDimensionList = 5,
2525
NotAllowedToParseAgain = 6,
2626
UndefinedNode = 7,
27-
};
28-
29-
enum class ErrorExecute
30-
{
31-
None = 0,
32-
InvalidRoot = 1,
33-
NotEnoughInputTensors = 2,
34-
TooManyInputTensors = 3,
35-
NullPtrAsInputTensor = 5,
27+
InvalidRoot = 8,
3628

3729
err_wrong_dtype = 101,
3830
err_wrong_dimension = 102,
@@ -51,6 +43,15 @@ namespace mini_jit
5143
err_shared_required_for_parallel_execution = 115,
5244
};
5345

46+
enum class ErrorExecute
47+
{
48+
None = 0,
49+
InvalidRoot = 1,
50+
NotEnoughInputTensors = 2,
51+
TooManyInputTensors = 3,
52+
NullPtrAsInputTensor = 5,
53+
};
54+
5455
enum class NodeType
5556
{
5657
Leaf,
@@ -63,6 +64,7 @@ namespace mini_jit
6364
NodeType type;
6465
int32_t input_tensor_index = -1;
6566
float *tensor = nullptr;
67+
mini_jit::TensorOperation tensor_op;
6668

6769
// Always filled — dims of the output tensor
6870
std::vector<int64_t> output_dim_ids;
@@ -198,7 +200,7 @@ namespace mini_jit
198200
* @param error The error code from TensorOperation.
199201
* @return An ErrorExecute enum representing the parsed error.
200202
*/
201-
ErrorExecute parse_setup_error(TensorOperation::error_t error);
203+
ErrorParse parse_setup_error(TensorOperation::error_t error);
202204

203205
// Cleanup
204206
/**
@@ -231,6 +233,13 @@ namespace mini_jit
231233
*/
232234
int32_t findMDim(EinsumNode *Node);
233235

236+
/**
237+
* @brief Generates the operator to the parsed einsum tree.
238+
*
239+
* @return ErrorParse indicating the result of the parsing operation.
240+
*/
241+
ErrorParse generate_operators();
242+
234243
public:
235244
EinsumTree(const std::string &tree_str);
236245
EinsumTree(const std::string &tree_str, const std::vector<int64_t> &sorted_dim_sizes);
@@ -248,9 +257,10 @@ namespace mini_jit
248257
/**
249258
* Parses the einsum tree string and builds the tree structure.
250259
*
260+
* @param build_operators indicates if the operators should be generate with the parse.
251261
* @return ErrorParse indicating the result of the parsing operation.
252262
*/
253-
ErrorParse parse_tree_no_optimization();
263+
ErrorParse parse_tree_no_optimization(bool build_operators = true);
254264

255265
/**
256266
* Parses the einsum tree string, builds the tree structure and optimizes the tree.

src/test/EinsumTree.test.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TEST_CASE("Test einsum tree parser simple example", "[einsumtree][parse][correct
2424

2525
EinsumTree tree(tree_str, sorted_dim_sizes);
2626

27-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
27+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
2828
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
2929
REQUIRE(tree.get_root() != nullptr);
3030
INFO(tree.get_root()->to_string());
@@ -47,7 +47,7 @@ TEST_CASE("Test einsum tree parser first example", "[einsumtree][parse][correctn
4747

4848
EinsumTree tree(tree_str, sorted_dim_sizes);
4949

50-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
50+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
5151
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
5252
REQUIRE(tree.get_root() != nullptr);
5353
INFO(tree.get_root()->to_string());
@@ -70,7 +70,7 @@ TEST_CASE("Test einsum tree parser second example", "[einsumtree][parse][correct
7070

7171
EinsumTree tree(tree_str, sorted_dim_sizes);
7272

73-
mini_jit::EinsumTree::ErrorParse err_parse = tree.parse_tree_no_optimization();
73+
mini_jit::EinsumTree::ErrorParse err_parse = tree.parse_tree_no_optimization(false);
7474
REQUIRE(err_parse == mini_jit::EinsumTree::ErrorParse::None);
7575
REQUIRE(tree.get_root() != nullptr);
7676
INFO(tree.get_root()->to_string());
@@ -236,7 +236,7 @@ TEST_CASE("Test einsum tree optimize swap", "[einsumtree][optimize][correctness]
236236

237237
EinsumTree tree(tree_str, sorted_dim_sizes);
238238

239-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
239+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
240240
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
241241
REQUIRE(tree.get_root() != nullptr);
242242
INFO(tree.get_root()->to_string());
@@ -264,7 +264,7 @@ TEST_CASE("Test einsum tree optimize reorder left", "[einsumtree][optimize][corr
264264

265265
EinsumTree tree(tree_str, sorted_dim_sizes);
266266

267-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
267+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
268268
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
269269
REQUIRE(tree.get_root() != nullptr);
270270
INFO(tree.get_root()->to_string());
@@ -300,7 +300,7 @@ TEST_CASE("Test einsum tree optimize reorder right", "[einsumtree][optimize][cor
300300

301301
EinsumTree tree(tree_str, sorted_dim_sizes);
302302

303-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
303+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
304304
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
305305
REQUIRE(tree.get_root() != nullptr);
306306
INFO(tree.get_root()->to_string());
@@ -338,7 +338,7 @@ TEST_CASE("Test einsum tree optimize", "[einsumtree][optimize][correctness]")
338338

339339
EinsumTree tree(tree_str, sorted_dim_sizes);
340340

341-
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization();
341+
mini_jit::EinsumTree::ErrorParse err = tree.parse_tree_no_optimization(false);
342342
REQUIRE(err == mini_jit::EinsumTree::ErrorParse::None);
343343
REQUIRE(tree.get_root() != nullptr);
344344
INFO(tree.get_root()->to_string());
@@ -420,7 +420,7 @@ TEST_CASE("Test einsum tree optimize and execute first example", "[einsumtree][e
420420
" └─ 0,5\n";
421421
REQUIRE_THAT(expected_optimization, Catch::Matchers::Equals(tree.get_root()->to_string(), Catch::CaseSensitive::Yes));
422422

423-
tree_no_optimization.parse_tree_no_optimization();
423+
tree_no_optimization.parse_tree_no_optimization(false);
424424
INFO("No Optimization");
425425
INFO(tree_no_optimization.get_root()->to_string());
426426

@@ -483,7 +483,7 @@ TEST_CASE("Test einsum tree optimize and execute second example", "[einsumtree][
483483
" └─ 1,4,7,8\n";
484484
REQUIRE_THAT(expected_optimization, Catch::Matchers::Equals(tree.get_root()->to_string(), Catch::CaseSensitive::Yes));
485485

486-
tree_no_optimization.parse_tree_no_optimization();
486+
tree_no_optimization.parse_tree_no_optimization(false);
487487
INFO("No Optimization");
488488
INFO(tree_no_optimization.get_root()->to_string());
489489

@@ -550,7 +550,7 @@ TEST_CASE("Test einsum tree optimize and execute third example", "[einsumtree][e
550550

551551
REQUIRE_THAT(expected_optimization, Catch::Matchers::Equals(tree.get_root()->to_string(), Catch::CaseSensitive::Yes));
552552

553-
tree_no_optimization.parse_tree_no_optimization();
553+
tree_no_optimization.parse_tree_no_optimization(false);
554554
INFO("No Optimization");
555555
INFO(tree_no_optimization.get_root()->to_string());
556556

0 commit comments

Comments
 (0)