Skip to content

Commit

Permalink
some updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanjzhao committed Aug 28, 2024
1 parent 7522d63 commit b8a734b
Show file tree
Hide file tree
Showing 124 changed files with 146 additions and 163,582 deletions.
78 changes: 41 additions & 37 deletions examples/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export class MuJoCoDemo {
this.simulation = new mujoco.Simulation(this.model, this.state);

// Define Random State Variables
this.params = { scene: initialScene, paused: false, useModel: true, help: false, ctrlnoiserate: 0.0, ctrlnoisestd: 0.0, keyframeNumber: 0 };
this.params = { scene: initialScene, paused: false, useModel: true, help: false, ctrlnoiserate: 0.0, ctrlnoisestd: 0.0, keyframeNumber: 0, ikEnabled: false};
this.mujoco_time = 0.0;
this.bodies = {}, this.lights = {};
this.tmpVec = new THREE.Vector3();
Expand Down Expand Up @@ -72,7 +72,6 @@ export class MuJoCoDemo {
this.loadPPOModel();

this.ikTarget = new THREE.Vector3();
this.ikEnabled = false;
this.ikJoints = ['shoulder1_left', 'shoulder2_left', 'elbow_left'];
this.ikEndEffector = 'hand_left';

Expand Down Expand Up @@ -123,29 +122,33 @@ export class MuJoCoDemo {

// should re-get pawel-diff
async loadPPOModel() {
let modelPath;
switch (this.params.scene) {
case 'humanoid.xml':
this.ppo_model = await tf.loadLayersModel('models/2_frame/model.json');
this.getObservation = () => this.getObservationSkeleton(2, 10, 6);
modelPath = 'models/humanoid_stand_6_frams_noise_1e-4/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 10, 6);
break;
case 'blank':
this.ppo_model = await tf.loadLayersModel('models/cvals+2_frames/model.json');
modelPath = 'models/cvals+2_frames/model.json';
break;
case 'brax_humanoid.xml':
this.ppo_model = await tf.loadLayersModel('models/brax_humanoid_cvalless_just_stand/model.json');
modelPath = 'models/brax_humanoid_cvalless_just_stand/model.json';
this.getObservation = () => this.getObservationSkeleton(0, -1, -1);
break;
case 'brax_humanoidstandup.xml':
this.ppo_model = await tf.loadLayersModel('models/brax_humanoid_standup/model.json');
modelPath = 'models/brax_humanoid_standup/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 20, 12);
break;
case 'dora/dora2.xml':
this.ppo_model = await tf.loadLayersModel('models/dora/model.json');
this.getObservation = () => this.getObservationSkeleton(0, 100, 72); // 172 diff total
modelPath = 'models/dora/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 100, 72);
break;
default:
throw new Error(`Unknown Tensorflow.js model for XML path: ${this.params.scene}`);
}

console.log(`Loading model from path: ${modelPath}`);
this.ppo_model = await tf.loadLayersModel(modelPath);
}

getObservationSkeleton(qpos_slice, cinert_slice, cvel_slice) {
Expand Down Expand Up @@ -221,23 +224,11 @@ export class MuJoCoDemo {
case 'j':
this.moveActuator('elbow_', -stepSize);
break;
case 'i':
this.toggleIK();
break;
}
}

/* Inverse Kinematics */
toggleIK() {
this.ikEnabled = !this.ikEnabled;
if (this.ikEnabled) {
this.originalQpos = new Float64Array(this.simulation.qpos);
}
console.log(`IK ${this.ikEnabled ? 'enabled' : 'disabled'}`);
}

solveIK(id) {
if (!this.ikEnabled) return;
if (!this.params.ikEnabled) return;

const bodyName = this.getBodyNameById(id);
if (bodyName !== this.ikEndEffector) {
Expand Down Expand Up @@ -317,9 +308,6 @@ export class MuJoCoDemo {

// Add the iteration record to the overall IK control records
this.ikControlRecords.push(iterationRecord);

// Optionally, you can save the records to a file here or provide a method to do so
// this.saveIKControlRecords();
}

calculateJacobian(jointId) {
Expand Down Expand Up @@ -495,14 +483,21 @@ export class MuJoCoDemo {

if (!this.params["paused"]) {

// // reset to original state before paused
// if (this.pausedState) {
// this.simulation.qpos.set(this.pausedState.qpos);
// this.simulation.ctrl.set(this.pausedState.ctrl);
// this.pausedState = null;
// }

// Update originalQpos when unpaused
if (!this.originalQpos || this.originalQpos.length !== this.simulation.qpos.length) {
this.originalQpos = new Float64Array(this.simulation.qpos);
} else {
this.originalQpos.set(this.simulation.qpos);
}

if (this.ppo_model && this.params["useModel"]) {
if (this.ppo_model && this.params.useModel) {
const observationArray = this.getObservation();
const inputTensor = tf.tensor2d([observationArray]);
const resultTensor = this.ppo_model.predict(inputTensor);
Expand All @@ -515,18 +510,20 @@ export class MuJoCoDemo {
// Ensure the actuator index is within bounds
if (i < this.simulation.ctrl.length) {

let clippedValue = Math.max(-1, Math.min(1, data[i]));
// let clippedValue = Math.max(-1, Math.mi(1, data[i]));

let [min, max] = this.actuatorRanges[i];
// let [min, max] = this.actuatorRanges[i];

// Scale to fit between min and max
let newValue = min + (clippedValue + 1) * (max - min) / 2;
// // Scale to fit between min and maxn
// let newValue = min + (clippedValue + 1) * (max - min) / 2;

// Update the actuator value
this.simulation.ctrl[i] = newValue;
// // Update the actuator value
// this.simulation.ctrl[i] = newValue;

// Optionally, update the corresponding parameter
this.params[this.actuatorNames[i]] = newValue;
// // Optionally, update the corresponding parameter
// this.params[this.actuatorNames[i]] = newValue;
this.simulation.ctrl[i] = data[i];
this.params[this.actuatorNames[i]] = data[i];
} else {
console.error('Model output index out of bounds:', i);
}
Expand All @@ -538,8 +535,7 @@ export class MuJoCoDemo {
if (timeMS - this.mujoco_time > 35.0) { this.mujoco_time = timeMS; }
while (this.mujoco_time < timeMS) {


// updates states from dragging
// updates states from dragging
// Jitter the control state with gaussian random noise
if (this.params["ctrlnoisestd"] > 0.0) {
let rate = Math.exp(-timestep / Math.max(1e-10, this.params["ctrlnoiserate"]));
Expand Down Expand Up @@ -577,8 +573,16 @@ export class MuJoCoDemo {
}
} else if (this.params["paused"]) {

// // store the state on pause to restore to later
// if (!this.pausedState) {
// this.pausedState = {
// qpos: new Float64Array(this.simulation.qpos),
// ctrl: new Float64Array(this.simulation.ctrl)
// };
// }

this.dragStateManager.update(); // Update the world-space force origin
if (this.ikEnabled) {
if (this.params.ikEnabled) {
let dragged = this.dragStateManager.physicsObject;
if (dragged && dragged.bodyID) {
this.ikTarget.copy(this.dragStateManager.currentWorld);
Expand Down
111 changes: 95 additions & 16 deletions examples/mujocoUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@ import * as THREE from 'three';
import { Reflector } from './utils/Reflector.js';
import { MuJoCoDemo } from './main.js';


export function printJointInfo(model, simulation) {
let textDecoder = new TextDecoder("utf-8");
let nullChar = textDecoder.decode(new ArrayBuffer(1));

console.log("Joint Information:");
for (let i = 0; i < model.njnt; i++) {
let name = textDecoder.decode(
model.names.subarray(
model.name_jntadr[i]
)
).split(nullChar)[0];

let qposAdr = model.jnt_qposadr[i];
let qposNum = model.jnt_type[i] === 0 ? 7 : model.jnt_type[i]; // 7 for free joint, otherwise use jnt_type

let positions = simulation.qpos.slice(qposAdr, qposAdr + qposNum);

console.log(`Joint ${i}: ${name}`);
console.log(` Position in qpos: ${qposAdr}`);
console.log(` Values: ${positions}`);
}
}

export async function reloadFunc() {
// Delete the old scene and load the new scene
this.scene.remove(this.scene.getObjectByName("MuJoCo Root"));
Expand All @@ -10,20 +34,11 @@ export async function reloadFunc() {

// console.log(this.model, this.state, this.simulation, this.bodies, this.lights);

// this.model.setOption("integrator", mujoco.INTEGRATOR_IMPLICIT);
// this.model.setOption("dt", 0.002); // Same as timestep in XML
// this.model.setOption("iterations", 50);
// this.model.setOption("solver", mujoco.SOLVER_NEWTON);
// this.model.setOption("tolerance", 1e-10);
// this.model.setOption("impratio", 1);
// this.model.setOption("noslip_iterations", 5);
// this.model.setOption("noslip_tolerance", 1e-6);
// this.model.setOption("mpr_iterations", 50);
// this.model.setOption("mpr_tolerance", 1e-6);
// this.model.setOption("apirate", 1);
// this.model.setOption("cone", mujoco.CONE_ELLIPTIC);
// this.model.setOption("jacobian", mujoco.JACOBIAN_DENSE);
// this.model.setOption("collision", mujoco.COLLISION_ALL);
// Log the current options
const options = this.model.getOptions();
console.log("Current model options:", options);

// this.pausedState = null;

// Initialize originalQpos with zeros
this.originalQpos = new Float64Array(this.simulation.qpos.length);
Expand Down Expand Up @@ -318,11 +333,16 @@ export function setupGUI(parentContext) {
pausedText.style.color = 'white';
pausedText.style.font = 'normal 18px Arial';
pausedText.innerHTML = 'pause';
pausedText.id = 'paused-text';
parentContext.container.appendChild(pausedText);
} else {
parentContext.container.removeChild(parentContext.container.lastChild);
const pausedText = document.getElementById('paused-text');
if (pausedText) {
parentContext.container.removeChild(pausedText);
}
}
});

document.addEventListener('keydown', (event) => {
if (event.code === 'Space') {
parentContext.params.paused = !parentContext.params.paused;
Expand All @@ -333,10 +353,67 @@ export function setupGUI(parentContext) {
actionInnerHTML += 'Play / Pause<br>';
keyInnerHTML += 'Space<br>';

// Add enable / disable model checkbox.
// Add IK enabled checkbox
parentContext.params.ikEnabled = false;
let ikEnabledCheckbox = simulationFolder.add(parentContext.params, 'ikEnabled').name('IK Enabled').listen();
ikEnabledCheckbox.onChange((value) => {
if (value) {
const ikEnabledText = document.createElement('div');
ikEnabledText.style.position = 'absolute';
ikEnabledText.style.top = '40px'; // Position it below the 'pause' text
ikEnabledText.style.left = '10px';
ikEnabledText.style.color = 'white';
ikEnabledText.style.font = 'normal 18px Arial';
ikEnabledText.innerHTML = 'IK enabled';
ikEnabledText.id = 'ik-enabled-text';
parentContext.container.appendChild(ikEnabledText);
} else {
const ikEnabledText = document.getElementById('ik-enabled-text');
if (ikEnabledText) {
parentContext.container.removeChild(ikEnabledText);
}
}
});

// Add keyboard shortcut for toggling IK
document.addEventListener('keydown', (event) => {
if (event.ctrlKey && event.code === 'KeyI') {
parentContext.params.ikEnabled = !parentContext.params.ikEnabled;
ikEnabledCheckbox.setValue(parentContext.params.ikEnabled);
event.preventDefault();
console.log(`IK ${parentContext.params.ikEnabled ? 'enabled' : 'disabled'}`);
}
});
actionInnerHTML += 'Toggle IK<br>';
keyInnerHTML += 'I<br>';

// Add model enabled checkbox
parentContext.params.useModel = true;
let modelEnabledCheckbox = simulationFolder.add(parentContext.params, 'useModel').name('Model Enabled').listen();
modelEnabledCheckbox.onChange((value) => {
if (value) {
const modelEnabledText = document.createElement('div');
modelEnabledText.style.position = 'absolute';
modelEnabledText.style.top = '70px'; // Position it below the 'IK enabled' text
modelEnabledText.style.left = '10px';
modelEnabledText.style.color = 'white';
modelEnabledText.style.font = 'normal 18px Arial';
modelEnabledText.innerHTML = 'Model enabled';
modelEnabledText.id = 'model-enabled-text';
parentContext.container.appendChild(modelEnabledText);
} else {
const modelEnabledText = document.getElementById('model-enabled-text');
if (modelEnabledText) {
parentContext.container.removeChild(modelEnabledText);
}
}
});

// Add keyboard shortcut for toggling model
document.addEventListener('keydown', (event) => {
if (event.ctrlKey && event.code === 'KeyM') {
parentContext.params.useModel = !parentContext.params.useModel;
modelEnabledCheckbox.updateDisplay();
event.preventDefault();
}
});
Expand Down Expand Up @@ -490,6 +567,8 @@ export async function loadSceneFromURL(mujoco, filename, parent) {
let state = parent.state;
let simulation = parent.simulation;

printJointInfo(parent.model, parent.simulation);

// Decode the null-terminated string names.
let textDecoder = new TextDecoder("utf-8");
let fullString = textDecoder.decode(model.names);
Expand Down
19 changes: 0 additions & 19 deletions examples/scenes/agility_cassie/LICENSE

This file was deleted.

28 changes: 0 additions & 28 deletions examples/scenes/agility_cassie/README.md

This file was deleted.

Loading

0 comments on commit b8a734b

Please sign in to comment.