"""MuJoCo physics backend for Udaan.
Requires mujoco package.
Install with: pip install udaan
"""
import os
import time
import mujoco
import numpy as np
from ... import _FOLDER_PATH
from ...utils.logging import get_logger
_logger = get_logger(__name__)
class _GlfwViewer:
    """Lightweight GLFW-based MuJoCo viewer with mouse interaction.

    Opens a GLFW window, renders the MuJoCo scene through a tracking camera,
    draws optional per-key trails plus start/target markers, and can capture
    frames for a GIF/MP4 recording.

    Controls: left drag rotates, right drag translates, middle drag or the
    scroll wheel zooms; ESC or Q requests the window to close.
    """

    # Colour used for a trail when the caller does not supply one.
    _DEFAULT_TRAIL_RGBA = (0.2, 0.6, 1.0, 0.6)
    # Minimum wall-clock spacing between two rendered frames (~60 fps).
    _RENDER_PERIOD = 1.0 / 60.0

    def __init__(self, model, data, width=1200, height=900, title="udaan", record=None):
        """Create the window, the MuJoCo render objects and the input callbacks.

        Args:
            model: ``mujoco.MjModel`` to visualise.
            data: ``mujoco.MjData`` whose state is drawn on every render.
            width: Window width in screen coordinates.
            height: Window height in screen coordinates.
            title: Window title.
            record: Output path for a recording. When truthy, the window is
                forced to 640x480 and rendered frames are captured.

        Raises:
            RuntimeError: If GLFW cannot be initialised or the window cannot
                be created.
        """
        if record:
            width, height = 640, 480
        # Imported lazily so the module can be loaded without GLFW installed.
        import glfw

        self._model = model
        self._data = data
        self._glfw = glfw
        self._record_path = record
        self._frames = [] if record else None
        self._record_every = 4  # capture every Nth rendered frame
        self._render_count = 0

        if not glfw.init():
            raise RuntimeError("Failed to initialize GLFW")
        self._window = glfw.create_window(width, height, title, None, None)
        if not self._window:
            glfw.terminate()
            raise RuntimeError("Failed to create GLFW window")
        glfw.make_context_current(self._window)
        glfw.swap_interval(1)

        self._scene = mujoco.MjvScene(model, maxgeom=10000)
        self._context = mujoco.MjrContext(model, mujoco.mjtFontScale.mjFONTSCALE_150)
        self._camera = mujoco.MjvCamera()
        self._option = mujoco.MjvOption()
        self._perturb = mujoco.MjvPerturb()
        mujoco.mjv_defaultCamera(self._camera)
        mujoco.mjv_defaultOption(self._option)
        mujoco.mjv_defaultPerturb(self._perturb)

        # Follow body 1 from a distance that scales with the model extent.
        self._camera.type = mujoco.mjtCamera.mjCAMERA_TRACKING
        self._camera.trackbodyid = 1
        self._camera.distance = max(model.stat.extent * 1.5, 1.5)
        self._camera.elevation = -20
        self._camera.azimuth = 135

        # Mouse interaction state
        self._button_left = False
        self._button_right = False
        self._button_middle = False
        self._last_mouse_x = 0.0
        self._last_mouse_y = 0.0

        # Register GLFW callbacks
        glfw.set_mouse_button_callback(self._window, self._mouse_button_callback)
        glfw.set_cursor_pos_callback(self._window, self._mouse_move_callback)
        glfw.set_scroll_callback(self._window, self._scroll_callback)
        glfw.set_key_callback(self._window, self._key_callback)

        self._last_render_time = 0.0
        self._overlay_text = ""

        # Visual markers: per-entity trails, starts, targets
        self.show_trails = True
        self._trails = {}  # key -> list of positions
        self._trail_colors = {}  # key -> rgba
        self._trail_max = 200
        self._starts = {}  # key -> position
        self._targets = {}  # key -> position

    @property
    def cam(self):
        """The ``mujoco.MjvCamera`` used for rendering."""
        return self._camera

    # ------------------------------------------------------------------
    # GLFW callbacks
    # ------------------------------------------------------------------
    def _mouse_button_callback(self, window, button, action, mods):
        """Track which mouse buttons are held and where the drag started."""
        glfw = self._glfw
        pressed = action == glfw.PRESS
        if button == glfw.MOUSE_BUTTON_LEFT:
            self._button_left = pressed
        elif button == glfw.MOUSE_BUTTON_RIGHT:
            self._button_right = pressed
        elif button == glfw.MOUSE_BUTTON_MIDDLE:
            self._button_middle = pressed
        # Reset the drag origin so the next move event yields a small delta.
        self._last_mouse_x, self._last_mouse_y = glfw.get_cursor_pos(window)

    def _mouse_move_callback(self, window, xpos, ypos):
        """Turn a mouse drag into a camera move (right > left > middle)."""
        dx = xpos - self._last_mouse_x
        dy = ypos - self._last_mouse_y
        self._last_mouse_x = xpos
        self._last_mouse_y = ypos

        if self._button_right:
            action = mujoco.mjtMouse.mjMOUSE_MOVE_V
        elif self._button_left:
            action = mujoco.mjtMouse.mjMOUSE_ROTATE_V
        elif self._button_middle:
            action = mujoco.mjtMouse.mjMOUSE_ZOOM
        else:
            return  # no button held: nothing to do

        width, height = self._glfw.get_window_size(window)
        if width <= 0 or height <= 0:
            # A zero-sized window would make the normalisation below divide
            # by zero.
            return
        mujoco.mjv_moveCamera(
            self._model, action, dx / width, dy / height, self._scene, self._camera
        )

    def _scroll_callback(self, window, xoffset, yoffset):
        """Zoom the camera with the scroll wheel."""
        mujoco.mjv_moveCamera(
            self._model, mujoco.mjtMouse.mjMOUSE_ZOOM, 0, -0.05 * yoffset, self._scene, self._camera
        )

    def _key_callback(self, window, key, scancode, action, mods):
        """Request window close when ESC or Q is pressed."""
        glfw = self._glfw
        if action == glfw.PRESS and key in (glfw.KEY_ESCAPE, glfw.KEY_Q):
            glfw.set_window_should_close(window, True)

    # ------------------------------------------------------------------
    # Markers
    # ------------------------------------------------------------------
    def _add_geom(self, type_, size, pos, rgba, mat=None):
        """Add a custom geom to the scene for the current frame.

        Silently does nothing once the scene already holds ``maxgeom`` geoms.

        Args:
            type_: ``mujoco.mjtGeom`` value.
            size: Geom size vector.
            pos: Geom position.
            rgba: Geom colour.
            mat: 3x3 orientation matrix; identity when omitted.
        """
        if self._scene.ngeom >= self._scene.maxgeom:
            return
        # Build the default per call instead of sharing one array object as a
        # default argument.
        mat = np.eye(3) if mat is None else np.asarray(mat, dtype=np.float64)
        geom = self._scene.geoms[self._scene.ngeom]
        mujoco.mjv_initGeom(geom, type_, size, pos, mat.flatten(), rgba)
        self._scene.ngeom += 1

    def add_trail_point(self, pos, key=0, rgba=None):
        """Append a position to a keyed trail.

        The colour is fixed by the first call for a given ``key``; each trail
        keeps at most ``_trail_max`` of its most recent points.
        """
        if key not in self._trails:
            self._trails[key] = []
            color = self._DEFAULT_TRAIL_RGBA if rgba is None else rgba
            self._trail_colors[key] = np.array(color, dtype=np.float32)
        trail = self._trails[key]
        trail.append(np.array(pos, dtype=np.float64))
        # Trim in place; computing the excess explicitly also behaves
        # correctly when ``_trail_max`` is 0.
        excess = len(trail) - self._trail_max
        if excess > 0:
            del trail[:excess]

    def set_start(self, pos, key=0):
        """Set the start-marker position for ``key``."""
        self._starts[key] = np.array(pos, dtype=np.float64)

    def set_target(self, pos, key=0):
        """Set the target-marker position for ``key``."""
        self._targets[key] = np.array(pos, dtype=np.float64)

    def _render_markers(self):
        """Draw trails, start markers, and target markers."""
        sphere = mujoco.mjtGeom.mjGEOM_SPHERE
        if self.show_trails:
            trail_size = np.array([0.008, 0, 0], dtype=np.float64)
            fallback = np.array(self._DEFAULT_TRAIL_RGBA, dtype=np.float32)
            for key, trail in self._trails.items():
                rgba = self._trail_colors.get(key, fallback)
                # Draw every second point to limit the number of geoms.
                for pt in trail[::2]:
                    self._add_geom(sphere, trail_size, pt, rgba)

        marker_size = np.array([0.03, 0, 0], dtype=np.float64)
        green = np.array([0.2, 0.9, 0.3, 0.8], dtype=np.float32)
        red = np.array([0.9, 0.2, 0.2, 0.8], dtype=np.float32)
        for pos in self._starts.values():
            self._add_geom(sphere, marker_size, pos, green)
        for pos in self._targets.values():
            self._add_geom(sphere, marker_size, pos, red)

    # ------------------------------------------------------------------
    # Rendering / recording
    # ------------------------------------------------------------------
    def render(self):
        """Process window events and draw one frame, throttled to ~60 fps.

        Does nothing once the window has been closed or destroyed.
        """
        if self._window is None:
            # close() already destroyed the window and terminated GLFW.
            return
        glfw = self._glfw
        if glfw.window_should_close(self._window):
            return

        # Always process events so window responds to close/input
        glfw.poll_events()

        # Throttle rendering to ~60fps
        now = time.monotonic()
        if now - self._last_render_time < self._RENDER_PERIOD:
            return
        self._last_render_time = now

        viewport = mujoco.MjrRect(0, 0, *glfw.get_framebuffer_size(self._window))
        if viewport.width <= 0 or viewport.height <= 0:
            # Nothing to draw for a zero-sized framebuffer, and an empty frame
            # would give the recording inconsistent shapes.
            return

        mujoco.mjv_updateScene(
            self._model,
            self._data,
            self._option,
            self._perturb,
            self._camera,
            mujoco.mjtCatBit.mjCAT_ALL,
            self._scene,
        )
        self._render_markers()
        mujoco.mjr_render(viewport, self._scene, self._context)

        # Overlay sim time top-right
        time_str = f"t = {self._data.time:.2f}s"
        mujoco.mjr_overlay(
            mujoco.mjtFont.mjFONT_NORMAL,
            mujoco.mjtGridPos.mjGRID_TOPRIGHT,
            viewport,
            time_str,
            "",
            self._context,
        )
        # Overlay legend top-left (if set)
        if self._overlay_text:
            mujoco.mjr_overlay(
                mujoco.mjtFont.mjFONT_NORMAL,
                mujoco.mjtGridPos.mjGRID_TOPLEFT,
                viewport,
                self._overlay_text,
                "",
                self._context,
            )

        # Capture frame for recording (every Nth rendered frame). This happens
        # before the buffer swap: the back buffer contents are undefined once
        # the buffers have been swapped.
        if self._frames is not None:
            self._render_count += 1
            if self._render_count % self._record_every == 0:
                rgb = np.empty((viewport.height, viewport.width, 3), dtype=np.uint8)
                mujoco.mjr_readPixels(rgb, None, viewport, self._context)
                # Flip vertically: the pixel rows are read bottom-up.
                self._frames.append(np.flipud(rgb))

        glfw.swap_buffers(self._window)

    def save_recording(self):
        """Save captured frames to file (GIF or MP4).

        Does nothing when no frames were captured or no path was given. The
        frame buffer is emptied before writing so a second call is a no-op.
        """
        if not self._frames or not self._record_path:
            return
        import imageio.v3 as iio

        path = self._record_path
        frames = list(self._frames)
        self._frames = []  # Prevent double-save
        _logger.info("Saving %d frames to %s", len(frames), path)
        # str()/lower() so path-like objects and upper-case ".GIF" also work.
        if str(path).lower().endswith(".gif"):
            iio.imwrite(path, frames, duration=1000 // 15, loop=0)
        else:
            # NOTE(review): frames are captured every `_record_every`-th render
            # of a ~60 fps loop (about 15 fps), so fps=30 here plays back
            # faster than the GIF branch - confirm this is intended.
            iio.imwrite(path, frames, fps=30)
        _logger.info("Saved recording to %s", path)

    # ------------------------------------------------------------------
    # Lifetime
    # ------------------------------------------------------------------
    def hold(self):
        """Keep window open until user closes it. Press ESC or Q to quit."""
        while self.is_alive():
            self.render()
            self._glfw.wait_events_timeout(self._RENDER_PERIOD)
        self.close()

    def close(self):
        """Save any pending recording, then destroy the window.

        Safe to call more than once. The window is released even when saving
        the recording raises; the exception still propagates afterwards.
        """
        try:
            self.save_recording()
        finally:
            if self._window:
                self._glfw.destroy_window(self._window)
                self._glfw.terminate()
                self._window = None

    def is_alive(self):
        """Return True while the window exists and has not been asked to close."""
        if self._window is None:
            return False
        return not self._glfw.window_should_close(self._window)
[docs]
class MujocoModel:
    """Base wrapper that loads an MJCF model and steps it with MuJoCo.

    The model file is looked up under ``udaan/models/assets/mjcf`` inside the
    package folder. When ``render`` is true a GLFW viewer is opened and
    stepping is paced against wall-clock time.
    """

    def __init__(self, model_path, render=False, record=None):
        """Resolve the model file and initialise the simulation.

        Args:
            model_path: MJCF file path relative to the bundled ``mjcf`` folder.
            render: Open a viewer window and render while stepping.
            record: Recording output path forwarded to the viewer.

        Raises:
            FileNotFoundError: If the resolved model file does not exist
                (a subclass of ``OSError``, so existing handlers still match).
        """
        self.full_path = os.path.join(_FOLDER_PATH, "udaan", "models", "assets", "mjcf", model_path)
        if not os.path.exists(self.full_path):
            raise FileNotFoundError(f"File {self.full_path} does not exist")
        self.render = render
        self._record = record
        self._viewer = None
        self.frame_skip = 1
        self._initialize_simulation()

    def _initialize_simulation(self):
        """Load the model, allocate its data and open the viewer if rendering."""
        _logger.info("Loading model from %s", self.full_path)
        self.model = mujoco.MjModel.from_xml_path(self.full_path)
        self.data = mujoco.MjData(self.model)
        # Wall-clock reference for real-time pacing; set on the first rendered step.
        self._wall_start = None
        if self.render:
            self._viewer = _GlfwViewer(self.model, self.data, record=self._record)

    def _viewer_active(self):
        """Return True when rendering is enabled and the viewer window is open."""
        return bool(self.render and self._viewer is not None and self._viewer.is_alive())

    def _step_mujoco_simulation(self, n_frames=1):
        """Advance the simulation by ``n_frames`` steps and render if possible.

        While a viewer window is open, the call sleeps as needed so simulation
        time does not run ahead of wall-clock time. Once the window has been
        closed the simulation keeps stepping without pacing or rendering.
        """
        mujoco.mj_step(self.model, self.data, n_frames)
        if not self._viewer_active():
            # No live window: do not sleep for real-time sync and do not touch
            # a viewer whose window may already be destroyed.
            return
        # Sync simulation to real-time
        if self._wall_start is None:
            self._wall_start = time.monotonic()
        ahead = self.data.time - (time.monotonic() - self._wall_start)
        if ahead > 0:
            time.sleep(ahead)
        self._viewer.render()

    def _quat2rot(self, q):
        """Convert a quaternion ``q`` into a 3x3 rotation matrix.

        ``q[0]`` is used as the scalar part and ``q[1:4]`` as the vector part.
        No normalisation is performed.
        # NOTE(review): the formula is the standard one for unit quaternions;
        # confirm callers always pass normalised input.
        """
        w, x, y, z = q[0], q[1], q[2], q[3]
        return np.array(
            [
                [
                    2 * (w * w + x * x) - 1,
                    2 * (x * y - w * z),
                    2 * (x * z + w * y),
                ],
                [
                    2 * (x * y + w * z),
                    2 * (w * w + y * y) - 1,
                    2 * (y * z - w * x),
                ],
                [
                    2 * (x * z - w * y),
                    2 * (y * z + w * x),
                    2 * (w * w + z * z) - 1,
                ],
            ]
        )

    def reset(self):
        """Reset the simulation data and restart real-time pacing."""
        mujoco.mj_resetData(self.model, self.data)
        self._wall_start = None

    def wait_for_close(self):
        """Keep the viewer open until the user closes the window.

        If recording, save and close immediately instead of holding.
        """
        if not self.render or self._viewer is None:
            return
        if self._viewer._record_path:
            self._viewer.close()
        else:
            self._viewer.hold()

    def set_overlay(self, text):
        """Set overlay text displayed in the top-left corner."""
        if self._viewer is not None:
            self._viewer._overlay_text = text

    def add_marker_at(self, p, size=None, rgba=None, label=""):
        """Visual markers not yet supported with built-in viewer."""

    def add_arrow_at(self, p, R, s, label="", color=None):
        """Visual markers not yet supported with built-in viewer."""
from ..quadrotor.mujoco import QuadrotorMujoco as Quadrotor
from .multi_quad_cs_pointmass import MultiQuadrotorCSPointmass as MultiQuadrotorCSPointmass
from .multi_quad_rigidbody import MultiQuadRigidbody as MultiQuadRigidbody
from .quadrotor_comparison import QuadrotorComparison as QuadrotorComparison
from .quadrotor_cspayload_fleet import QuadrotorCsPayloadFleet as QuadrotorCsPayloadFleet
from .quadrotor_fleet import QuadrotorFleet as QuadrotorFleet
# Names exported by `from <package> import *`; each one is defined or
# imported above in this file.
__all__ = [
    "MujocoModel",
    "MultiQuadrotorCSPointmass",
    "MultiQuadRigidbody",
    "Quadrotor",
    "QuadrotorComparison",
    "QuadrotorCsPayloadFleet",
    "QuadrotorFleet",
]