@@ -14,7 +14,7 @@ mini_jit::EinsumTree::EinsumTree(const std::string &tree_str, const std::vector<
1414{
1515}
1616
17- mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimization ()
17+ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimization (bool build_operators )
1818{
1919 if (root != nullptr )
2020 {
@@ -27,6 +27,16 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree_no_optimizatio
2727 tensorIndex = 0 ;
2828 assign_tensor_indices (root);
2929
30+ if (error_parse != ErrorParse::None)
31+ {
32+ return error_parse;
33+ }
34+
35+ if (build_operators)
36+ {
37+ error_parse = generate_operators ();
38+ }
39+
3040 return error_parse;
3141}
3242
@@ -496,20 +506,7 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
496506 node->tensor = new float [node->get_size (dim_sizes)]();
497507 }
498508
499- mini_jit::TensorOperation tensor_op;
500- TensorConfig config = lower_node (node);
501- TensorOperation::error_t error_setup = tensor_op.setup (config);
502- error = parse_setup_error (error_setup);
503-
504- if (error != ErrorExecute::None)
505- {
506- return error;
507- }
508-
509- #ifdef SAVE_JITS_TO_FILE
510- tensor_op.write_kernel_to_file (node->name ());
511- #endif // SAVE_JITS_TO_FILE
512- tensor_op.execute (node->left ->tensor , nullptr , node->tensor );
509+ node->tensor_op .execute (node->left ->tensor , nullptr , node->tensor );
513510 }
514511 else if (node->type == NodeType::Contraction)
515512 {
@@ -539,20 +536,7 @@ mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::execute_node(const std:
539536 node->tensor = new float [node->get_size (dim_sizes)]();
540537 }
541538
542- mini_jit::TensorOperation tensor_op;
543- TensorConfig config = lower_node (node);
544- TensorOperation::error_t error_setup = tensor_op.setup (config);
545- error = parse_setup_error (error_setup);
546-
547- if (error != ErrorExecute::None)
548- {
549- return error;
550- }
551-
552- #ifdef SAVE_JITS_TO_FILE
553- tensor_op.write_kernel_to_file (node->name ());
554- #endif // SAVE_JITS_TO_FILE
555- tensor_op.execute (node->left ->tensor , node->right ->tensor , node->tensor );
539+ node->tensor_op .execute (node->left ->tensor , node->right ->tensor , node->tensor );
556540 }
557541 else
558542 {
@@ -598,19 +582,19 @@ std::vector<int64_t> mini_jit::EinsumTree::get_output_dims(const std::vector<int
598582 return dims;
599583}
600584
601- mini_jit::EinsumTree::ErrorExecute mini_jit::EinsumTree::parse_setup_error (TensorOperation::error_t error)
585+ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_setup_error (TensorOperation::error_t error)
602586{
603587 if (error == TensorOperation::error_t ::success)
604588 {
605- return ErrorExecute ::None;
589+ return ErrorParse ::None;
606590 }
607591
608592 uint32_t error_num = static_cast <uint32_t >(error) + 100 ;
609593
610594 release_assert (error_num >= 101 , " Expected error_num to be larger equal than 101." );
611595 release_assert (error_num <= 115 , " Expected error_num to be less equal than 115." );
612596
613- return static_cast <ErrorExecute >(error_num);
597+ return static_cast <ErrorParse >(error_num);
614598}
615599
616600bool mini_jit::EinsumTree::is_unit_stride_n (EinsumNode *node)
@@ -673,7 +657,7 @@ void mini_jit::EinsumTree::conditional_swap(mini_jit::EinsumTree::EinsumNode *no
673657
674658mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree ()
675659{
676- ErrorParse error = parse_tree_no_optimization ();
660+ ErrorParse error = parse_tree_no_optimization (false );
677661
678662 if (error != ErrorParse::None)
679663 {
@@ -682,6 +666,77 @@ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::parse_tree()
682666
683667 optimize (root);
684668
669+ error = generate_operators ();
670+
671+ return error;
672+ }
673+
674+ mini_jit::EinsumTree::ErrorParse mini_jit::EinsumTree::generate_operators ()
675+ {
676+ if (root == nullptr )
677+ {
678+ std::cerr << " EinsumTree: Cannot execute, root is null." << std::endl;
679+ return ErrorParse::InvalidRoot;
680+ }
681+
682+ ErrorParse error = ErrorParse::None;
683+ std::vector<EinsumNode *> stack = {root};
684+
685+ while (stack.size () > 0 )
686+ {
687+ EinsumNode *node = stack.back ();
688+ stack.pop_back ();
689+
690+ if (node->type == NodeType::Leaf)
691+ {
692+ continue ;
693+ }
694+ else if (node->type == NodeType::Transposition)
695+ {
696+ release_assert (node->left != nullptr , " Expected the left child of contraction to be a valid pointer." );
697+ release_assert (node->right == nullptr , " Expected the right child of contraction to be a nullptr." );
698+
699+ stack.push_back (node->left );
700+
701+ TensorConfig config = lower_node (node);
702+ TensorOperation::error_t error_setup = node->tensor_op .setup (config);
703+ error = parse_setup_error (error_setup);
704+
705+ if (error != ErrorParse::None)
706+ {
707+ return error;
708+ }
709+
710+ #ifdef SAVE_JITS_TO_FILE
711+ node->tensor_op .write_kernel_to_file (node->name ());
712+ #endif // SAVE_JITS_TO_FILE
713+ }
714+ else if (node->type == NodeType::Contraction)
715+ {
716+ release_assert (node->left != nullptr , " Expected the left child of contraction to be a valid pointer." );
717+ release_assert (node->right != nullptr , " Expected the right child of contraction to be a valid pointer." );
718+
719+ stack.push_back (node->left );
720+ stack.push_back (node->right );
721+
722+ TensorConfig config = lower_node (node);
723+ TensorOperation::error_t error_setup = node->tensor_op .setup (config);
724+ error = parse_setup_error (error_setup);
725+
726+ if (error != ErrorParse::None)
727+ {
728+ return error;
729+ }
730+
731+ #ifdef SAVE_JITS_TO_FILE
732+ node->tensor_op .write_kernel_to_file (node->name ());
733+ #endif // SAVE_JITS_TO_FILE
734+ }
735+ else
736+ {
737+ release_assert (false , " Found unhandled einsum tree node type." );
738+ }
739+ }
685740 return ErrorParse::None;
686741}
687742
0 commit comments