Shortcuts

Source code for grutopia.core.task.metric

from abc import ABC, abstractmethod
from functools import wraps

from grutopia.core.config.metric import MetricCfg
from grutopia.core.runtime.task_runtime import TaskRuntime


[docs]class BaseMetric(ABC): metrics = {} def __init__(self, config: MetricCfg, task_runtime: TaskRuntime): self.config = config self.name = config.name self.task_runtime = task_runtime self.metric_config = config.metric_config @abstractmethod def reset(self): raise NotImplementedError(f'`reset` function of {self.name} is not implemented')
[docs] @abstractmethod def update(self, *args): """ This function is called at each world step. """ raise NotImplementedError(f'`update` function of {self.name} is not implemented')
[docs] @abstractmethod def calc(self): """ This function is called to calculate the metrics when the episode is terminated. """ raise NotImplementedError(f'`calc` function of {self.name} is not implemented')
[docs] @classmethod def register(cls, name: str): """ This function is used to register a metric class.(decorator) Args: name(str): name of the metric """ def decorator(metric_class): cls.metrics[name] = metric_class @wraps(metric_class) def wrapped_function(*args, **kwargs): return metric_class(*args, **kwargs) return wrapped_function return decorator
def create_metric(config: MetricCfg, task_runtime: TaskRuntime): if config.type not in BaseMetric.metrics: raise KeyError( f"""The metric {config.type} is not registered, please register it using `@BaseMetric.register`""" ) metric_cls = BaseMetric.metrics[config.type] return metric_cls(config, task_runtime)