5858
5959from optimas .generators .ax .base import AxGenerator
6060from optimas .core import (
61- TrialParameter ,
6261 Task ,
6362 Trial ,
6463 TrialStatus ,
@@ -154,9 +153,6 @@ class AxMultitaskGenerator(AxGenerator):
154153 VOCS object defining variables, objectives, constraints, and observables.
155154 lofi_task, hifi_task : Task
156155 The low- and high-fidelity tasks.
157- analyzed_parameters : list of Parameter, optional
158- List of parameters to analyze at each trial, but which are not
159- optimization objectives. By default ``None``.
160156 use_cuda : bool, optional
161157 Whether to allow the generator to run on a CUDA GPU. By default
162158 ``False``.
@@ -178,6 +174,8 @@ class AxMultitaskGenerator(AxGenerator):
178174
179175 """
180176
177+ returns_id = True
178+
181179 def __init__ (
182180 self ,
183181 vocs : VOCS ,
@@ -190,16 +188,6 @@ def __init__(
190188 model_save_period : Optional [int ] = 5 ,
191189 model_history_dir : Optional [str ] = "model_history" ,
192190 ) -> None :
193-
194- # As trial parameters these get written to history array
195- # Ax trial_index and arm toegther locate a point
196- # Multiple points (Optimas trials) can share the same Ax trial_index
197- # vocs interface note: These are not part of vocs. They are only stored
198- # to allow keeping track of them from previous runs.
199- custom_trial_parameters = [
200- TrialParameter ("arm_name" , "ax_arm_name" , dtype = "U32" ),
201- TrialParameter ("ax_trial_id" , "ax_trial_index" , dtype = int ),
202- ]
203191 self ._check_inputs (vocs , lofi_task , hifi_task )
204192
205193 super ().__init__ (
@@ -210,7 +198,6 @@ def __init__(
210198 save_model = save_model ,
211199 model_save_period = model_save_period ,
212200 model_history_dir = model_history_dir ,
213- custom_trial_parameters = custom_trial_parameters ,
214201 )
215202 self .lofi_task = lofi_task
216203 self .hifi_task = hifi_task
@@ -226,6 +213,10 @@ def __init__(
226213 self .gr_lofi = None
227214 self ._experiment = self ._create_experiment ()
228215
216+ # Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
217+ self ._id_mapping = {}
218+ self ._next_id = 0
219+
229220 def get_gen_specs (
230221 self , sim_workers : int , run_params : Dict , sim_max : int
231222 ) -> Dict :
@@ -285,11 +276,22 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
285276 if trial_param .name == "trial_type" :
286277 point [trial_param .name ] = trial_type
287278
288- point ["ax_trial_id" ] = trial_index
289- point ["arm_name" ] = arm .name
279+ # Generate unique _id and store mapping
280+ current_id = self ._next_id
281+ self ._id_mapping [current_id ] = {
282+ "ax_trial_id" : trial_index ,
283+ "arm_name" : arm .name ,
284+ }
285+ point ["_id" ] = current_id
286+ self ._next_id += 1
290287 points .append (point )
291288 return points
292289
290+ def _get_trial_mapping (self , gen_id : int ) -> Tuple [int , str ]:
291+ """Get mapping information for a trial gen_id."""
292+ mapping = self ._id_mapping [gen_id ]
293+ return mapping ["ax_trial_id" ], mapping ["arm_name" ]
294+
293295 def ingest (self , results : List [dict ]) -> None :
294296 """Incorporate evaluated trials into experiment."""
295297 # reconstruct Optimas trials
@@ -304,60 +306,77 @@ def ingest(self, results: List[dict]) -> None:
304306 )
305307 trials .append (trial )
306308
309+ # Apply _id mapping to all trials before processing
310+ for trial in trials :
311+ if trial .gen_id is not None :
312+ if trial .gen_id not in self ._id_mapping :
313+ raise ValueError (
314+ f"Trial has _id={ trial .gen_id } which is not recognized by this generator."
315+ )
316+ trial .ax_trial_id , trial .arm_name = self ._get_trial_mapping (
317+ trial .gen_id
318+ )
319+
307320 if self .gen_state == NOT_STARTED :
308321 self ._incorporate_external_data (trials )
309322 else :
310323 self ._complete_evaluations (trials )
311324
312325 def _incorporate_external_data (self , trials : List [Trial ]) -> None :
313- """Incorporate external data (e.g., from history) into experiment."""
314- # Get trial indices.
315- trial_indices = []
316- for trial in trials :
317- trial_indices .append (trial .ax_trial_id )
318- trial_indices = np .unique (np .array (trial_indices ))
319-
320- # Group trials by index.
321- grouped_trials = {}
322- for index in trial_indices :
323- grouped_trials [index ] = []
326+ """Incorporate external data (e.g., from history) into experiment.
327+
328+ Unknown/external points have no gen_id. We create new arms and add
329+ observations directly to the experiment, then let the model use them
330+ as if starting fresh.
331+ """
332+ # Group by trial_type (default to hifi if not specified)
333+ grouped_by_type = {}
324334 for trial in trials :
325- grouped_trials [ trial . ax_trial_id ]. append ( trial )
326-
327- # Add trials to experiment.
328- for index in trial_indices :
329- # Get all trials with current index.
330- trials_i = grouped_trials [ index ]
331- trial_type = trials_i [ 0 ]. trial_type
332- # Create arms.
335+ trial_type = getattr ( trial , "trial_type" , self . hifi_task . name )
336+ if trial_type not in grouped_by_type :
337+ grouped_by_type [ trial_type ] = []
338+ grouped_by_type [ trial_type ]. append ( trial )
339+
340+ param_to_name = {}
341+ arm_count = 0
342+ for trial_type , trials_i in grouped_by_type . items ():
333343 arms = []
334344 for trial in trials_i :
335345 params = {}
336346 for var , val in zip (
337347 trial .varying_parameters , trial .parameter_values
338348 ):
339349 params [var .name ] = val
340- arms .append (Arm (parameters = params , name = trial .arm_name ))
350+ arm = Arm (parameters = params )
351+ if arm .signature not in param_to_name :
352+ param_to_name [arm .signature ] = f"external_{ arm_count } "
353+ arm_count += 1
354+ arms .append (
355+ Arm (parameters = params , name = param_to_name [arm .signature ])
356+ )
357+ # self._next_id += 1
358+
341359 # Create new batch trial.
342360 gr = GeneratorRun (arms = arms , weights = [1.0 ] * len (arms ))
343361 ax_trial = self ._experiment .new_batch_trial (
344362 generator_run = gr , trial_type = trial_type
345363 )
346364 ax_trial .run ()
347365 # Incorporate observations.
348- for trial in trials_i :
366+ for i , trial in enumerate (trials_i ):
367+ arm_name = ax_trial .arms [i ].name
349368 if trial .status != TrialStatus .FAILED :
350369 objective_eval = {}
351370 oe = trial .objective_evaluations [0 ]
352371 objective_eval ["f" ] = (oe .value , oe .sem )
353- ax_trial .run_metadata [trial . arm_name ] = objective_eval
372+ ax_trial .run_metadata [arm_name ] = objective_eval
354373 else :
355- ax_trial .mark_arm_abandoned (trial . arm_name )
374+ ax_trial .mark_arm_abandoned (arm_name )
356375 # Mark batch trial as completed.
357376 ax_trial .mark_completed ()
358377 # Keep track of high-fidelity trials.
359378 if trial_type == self .hifi_task .name :
360- self .hifi_trials .append (index )
379+ self .hifi_trials .append (ax_trial . index )
361380
362381 def _complete_evaluations (self , trials : List [Trial ]) -> None :
363382 """Complete evaluated trials."""
0 commit comments