class BaseMetric(ABC):
    """Abstract base class for metrics, combined with a name -> class registry.

    Concrete metrics implement `reset`, `update` and `calc`, and are made
    discoverable by name through the `BaseMetric.register` decorator.
    """

    # Registry shared through the class hierarchy: registered name -> metric class.
    metrics = {}

    def __init__(self, config: MetricCfg, task_runtime: TaskRuntime):
        """Store the configuration and task runtime on the instance.

        Args:
            config (MetricCfg): metric configuration; its `name` and
                `metric_config` attributes are copied onto the instance.
            task_runtime (TaskRuntime): runtime object of the task, stored as-is.
        """
        self.config = config
        self.name = config.name
        self.task_runtime = task_runtime
        self.metric_config = config.metric_config

    @abstractmethod
    def reset(self):
        """Reset hook; must be implemented by subclasses.

        Raises:
            NotImplementedError: if the base implementation is invoked.
        """
        raise NotImplementedError(f'`reset` function of {self.name} is not implemented')

    @abstractmethod
    def update(self, *args):
        """
        This function is called at each world step.

        Raises:
            NotImplementedError: if the base implementation is invoked.
        """
        raise NotImplementedError(f'`update` function of {self.name} is not implemented')

    @abstractmethod
    def calc(self):
        """
        This function is called to calculate the metrics when the episode is terminated.

        Raises:
            NotImplementedError: if the base implementation is invoked.
        """
        raise NotImplementedError(f'`calc` function of {self.name} is not implemented')

    @classmethod
    def register(cls, name: str):
        """
        This function is used to register a metric class.(decorator)

        Args:
            name(str): name of the metric

        Returns:
            A class decorator that stores the decorated class in `metrics`
            under `name` and returns that class unchanged.
        """

        def decorator(metric_class):
            cls.metrics[name] = metric_class
            # Hand back the class itself (not a function wrapper) so the
            # decorated name remains a real class: isinstance checks,
            # subclassing and class-attribute access keep working.
            return metric_class

        return decorator
def create_metric(config: MetricCfg, task_runtime: TaskRuntime):
    """Instantiate the metric class registered under `config.type`.

    Args:
        config (MetricCfg): metric configuration; `config.type` selects the
            registered class and the whole config is passed to its constructor.
        task_runtime (TaskRuntime): task runtime forwarded to the constructor.

    Returns:
        The result of calling the registered metric class with
        `(config, task_runtime)`.

    Raises:
        KeyError: if `config.type` is not a key of `BaseMetric.metrics`.
    """
    registry = BaseMetric.metrics
    metric_type = config.type

    # Success path first; an unknown type falls through to the error below.
    if metric_type in registry:
        return registry[metric_type](config, task_runtime)

    raise KeyError(
        f"""The metric {config.type} is not registered, please register it using `@BaseMetric.register`"""
    )