Source code for grutopia.core.task.reward
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict
from grutopia.core.config.task import RewardCfg
from grutopia.core.task import BaseTask
class BaseReward(ABC):
    """Abstract base class for task rewards, with a name-keyed class registry.

    Concrete rewards subclass this, implement :meth:`reset`, :meth:`calc` and
    :meth:`_calc_next_state`, and register themselves with
    ``@BaseReward.register('<name>')`` so that ``create_reward`` can look the
    class up by name and instantiate it as ``cls(task, settings)``.
    """

    # Registry mapping reward name -> reward class. It is a class attribute of
    # BaseReward, so it is shared by the base class and all subclasses.
    rewards: Dict[str, type] = {}

    def __init__(self, task: 'BaseTask', settings: Dict[str, Any]):
        """
        Args:
            task: the task this reward belongs to; stored as ``self.task``.
            settings: reward-specific settings; stored as ``self.settings``.
        """
        self.state = None
        self.task = task
        self.settings = settings
        # Subclass hook; runs after task/settings are stored so it can read them.
        self.init_setting()

    def init_setting(self):
        """Hook for subclasses to process ``self.settings``; no-op by default."""
        pass

    @abstractmethod
    def reset(self):
        """Reset the reward's internal state; the base version clears ``self.state``."""
        self.state = None

    @abstractmethod
    def calc(self) -> float:
        """Compute and return the current reward value. Must be overridden."""
        # `name` is not defined by BaseReward itself; fall back to the class name
        # so this raises NotImplementedError rather than AttributeError.
        raise NotImplementedError(
            f'`calc` function of {getattr(self, "name", type(self).__name__)} is not implemented'
        )

    @abstractmethod
    def _calc_next_state(self):
        """Subclass hook for computing the reward's next internal state. Must be overridden."""
        raise NotImplementedError(
            f'`_calc_next_state` function of {getattr(self, "name", type(self).__name__)} is not implemented'
        )

    @classmethod
    def register(cls, name: str):
        """
        Class decorator that registers a reward class under ``name``.

        Args:
            name(str): name of the reward; the key used in ``rewards``.

        Returns:
            A decorator that stores the decorated class in ``rewards[name]``
            and returns that same class unchanged.
        """

        def decorator(reward_class):
            cls.rewards[name] = reward_class
            # Return the class itself (not a wrapper function) so the decorated
            # name stays a real class: isinstance(), subclassing and
            # classmethods keep working.
            return reward_class

        return decorator
def create_reward(reward_config: RewardCfg, task: BaseTask):
    """Instantiate the registered reward class named by ``reward_config.reward_type``.

    Args:
        reward_config: config whose ``reward_type`` selects the class in
            ``BaseReward.rewards`` and whose ``reward_settings`` are passed
            to that class's constructor.
        task: task instance passed as the first constructor argument.

    Returns:
        The result of calling the registered class with
        ``(task, reward_config.reward_settings)``.

    Raises:
        KeyError: if ``reward_config.reward_type`` has not been registered.
    """
    registry = BaseReward.rewards
    if reward_config.reward_type in registry:
        return registry[reward_config.reward_type](task, reward_config.reward_settings)
    raise KeyError(
        f"""The reward {reward_config.reward_type} is not registered, please register it using `@BaseReward.register`"""
    )