Source code for udaan.models.base
import numpy as np
from ...core.defaults import GRAVITY
from ...utils.logging import get_logger
_logger = get_logger(__name__)
[docs]
class BaseModel:
[docs]
def __init__(self, **kwargs):
self._g = GRAVITY
self._ge3 = np.array([0.0, 0.0, self._g])
self._e1 = np.array([1.0, 0.0, 0.0])
self._e2 = np.array([0.0, 1.0, 0.0])
self._e3 = np.array([0.0, 0.0, 1.0])
self._gravity = np.array([0.0, 0.0, -self._g])
self.sim_timestep = 0.002
self._n_state = 0
self._n_action = 0
self.t = 0.0
self.verbose = False
self.render = False
# matched disturbance, i.e., disturbance if added to the input before updating the dynamics
self.disturbance = False
return
def _parse_args(self, **kwargs):
for key, value in kwargs.items():
if key in self.__dict__.keys():
if type(self.__dict__[key]) is dict:
self.__dict__[key].update(value)
else:
self.__dict__[key] = value
else:
_logger.warning(f"Key {key} not found in environment")
[docs]
def reset(self):
raise NotImplementedError
[docs]
def step(self, action):
raise NotImplementedError
[docs]
def simulate(self, **kwargs):
raise NotImplementedError
[docs]
def get_action_size(self):
return self._n_action
[docs]
def get_state_size(self):
return self._n_state
from .floating_pointmass import FloatingPointmass as FloatingPointmass
from .multi_pointmass_suspended_payload import (
MultiPointmassSuspendedPayload as MultiPointmassSuspendedPayload,
)
from .pointmass_suspended_payload import (
PointmassSuspendedPayload as PointmassSuspendedPayload,
)
from .s2_pendulum import S2Pendulum as S2Pendulum
__all__ = [
"BaseModel",
"FloatingPointmass",
"MultiPointmassSuspendedPayload",
"PointmassSuspendedPayload",
"S2Pendulum",
]