Skip to content

Commit

Permalink
Merge pull request #5 from nathanjzhao/policy-in-wasm
Browse files Browse the repository at this point in the history
policy in wasm
  • Loading branch information
nathanjzhao authored Aug 29, 2024
2 parents b8a734b + 7bdfb6b commit 39af631
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 94 deletions.
215 changes: 135 additions & 80 deletions examples/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { DragStateManager } from './utils/DragStateManager.js';
import { setupGUI, downloadExampleScenesFolder, loadSceneFromURL, getPosition, getQuaternion, toMujocoPos, standardNormal } from './mujocoUtils.js';
import load_mujoco from '../dist/mujoco_wasm.js';

import { LogStdLayer } from './utils/LogStdLayer.js';

// Load the MuJoCo Module
const mujoco = await load_mujoco();

Expand Down Expand Up @@ -119,62 +121,6 @@ export class MuJoCoDemo {
}
}


// should re-get pawel-diff
async loadPPOModel() {
let modelPath;
switch (this.params.scene) {
case 'humanoid.xml':
modelPath = 'models/humanoid_stand_6_frams_noise_1e-4/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 10, 6);
break;
case 'blank':
modelPath = 'models/cvals+2_frames/model.json';
break;
case 'brax_humanoid.xml':
modelPath = 'models/brax_humanoid_cvalless_just_stand/model.json';
this.getObservation = () => this.getObservationSkeleton(0, -1, -1);
break;
case 'brax_humanoidstandup.xml':
modelPath = 'models/brax_humanoid_standup/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 20, 12);
break;
case 'dora/dora2.xml':
modelPath = 'models/dora/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 100, 72);
break;
default:
throw new Error(`Unknown Tensorflow.js model for XML path: ${this.params.scene}`);
}

console.log(`Loading model from path: ${modelPath}`);
this.ppo_model = await tf.loadLayersModel(modelPath);
}

getObservationSkeleton(qpos_slice, cinert_slice, cvel_slice) {
const qpos = this.simulation.qpos.slice(qpos_slice);
const qvel = this.simulation.qvel;
const cinert = cinert_slice !== -1 ? this.simulation.cinert.slice(cinert_slice) : [];
const cvel = cvel_slice !== -1 ? this.simulation.cvel.slice(cvel_slice) : [];
const qfrc_actuator = this.simulation.qfrc_actuator;

// console.log('qpos length:', qpos.length);
// console.log('qvel length:', qvel.length);
// console.log('cinert length:', cinert.length);
// console.log('cvel length:', cvel.length);
// console.log('qfrc_actuator length:', qfrc_actuator.length);

const obsComponents = [
...qpos,
...qvel,
...cinert,
...cvel,
...qfrc_actuator
];

return obsComponents;
}

handleKeyPress(event) {
const key = event.key.toLowerCase();
const stepSize = 0.1;
Expand Down Expand Up @@ -477,6 +423,92 @@ export class MuJoCoDemo {
this.renderer.setSize( window.innerWidth, window.innerHeight );
}


applyControls(timestep) {
for (let i = 0; i < this.actuatorNames.length; i++) {
const actuatorName = this.actuatorNames[i];
const jointIndex = this.model.actuator_trnid[2 * i];
const jointAddress = this.model.jnt_qposadr[jointIndex];

// Get current position and velocity
const currentPosition = this.simulation.qpos[jointAddress];
const currentVelocity = this.simulation.qvel[jointAddress];

// Get desired position from control input
const desiredPosition = this.params[actuatorName];

// PD control
const kp = 100; // Proportional gain
const kd = 10; // Derivative gain

const positionError = desiredPosition - currentPosition;
const velocityError = -currentVelocity; // Assuming desired velocity is 0

const control = kp * positionError + kd * velocityError;

// Apply the control force
this.simulation.qfrc_applied[jointAddress] += control;
console.log(this.simulation.qfrc_applied[jointAddress]);
}
}


// should re-get pawel-diff
async loadPPOModel() {
let modelPath;

switch (this.params.scene) {
case 'humanoid.xml':
modelPath = 'models/humanoid_stand_6_frams_noise_1e-4/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 10, 6);
break;
case 'blank':
modelPath = 'models/cvals+2_frames/model.json';
break;
case 'brax_humanoid.xml':
modelPath = 'models/brax_humanoid_cvalless_just_stand/model.json';
this.getObservation = () => this.getObservationSkeleton(0, -1, -1);
break;
case 'brax_humanoidstandup.xml':
modelPath = 'models/brax_humanoid_standup/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 20, 12);
break;
case 'dora/dora2.xml':
modelPath = 'models/dora/model.json';
this.getObservation = () => this.getObservationSkeleton(0, 100, 72);
break;
default:
throw new Error(`Unknown Tensorflow.js model for XML path: ${this.params.scene}`);
}

// Load the model with custom objects
this.ppo_model = await tf.loadLayersModel(modelPath);
}

getObservationSkeleton(qpos_slice, cinert_slice, cvel_slice) {
const qpos = this.simulation.qpos.slice(qpos_slice);
const qvel = this.simulation.qvel;
const cinert = cinert_slice !== -1 ? this.simulation.cinert.slice(cinert_slice) : [];
const cvel = cvel_slice !== -1 ? this.simulation.cvel.slice(cvel_slice) : [];
const qfrc_actuator = this.simulation.qfrc_actuator;

// console.log('qpos length:', qpos.length);
// console.log('qvel length:', qvel.length);
// console.log('cinert length:', cinert.length);
// console.log('cvel length:', cvel.length);
// console.log('qfrc_actuator length:', qfrc_actuator.length);

const obsComponents = [
...qpos,
...qvel,
...cinert,
...cvel,
...qfrc_actuator
];

return obsComponents;
}

// render loop
render(timeMS) {
this.controls.update();
Expand All @@ -500,35 +532,58 @@ export class MuJoCoDemo {
if (this.ppo_model && this.params.useModel) {
const observationArray = this.getObservation();
const inputTensor = tf.tensor2d([observationArray]);
const resultTensor = this.ppo_model.predict(inputTensor);

resultTensor.data().then(data => {
// console.log('Model output:', data);

console.log("Predicting...");
try {
const prediction = this.ppo_model.predict(inputTensor);

// Assuming the model output corresponds to actuator values
for (let i = 0; i < data.length; i++) {
// Ensure the actuator index is within bounds
if (!Array.isArray(prediction) || prediction.length !== 3) {
console.error('Unexpected prediction output:', prediction);
return;
}

const [actorMean, logStd, criticValue] = prediction;

console.log('Actor Mean:', actorMean.arraySync());
console.log('Log Std:', logStd.arraySync());
console.log('Critic Value:', criticValue.arraySync());

// Use tf.tidy to automatically dispose of intermediate tensors
tf.tidy(() => {
const stdDev = tf.exp(logStd);
const noise = tf.randomNormal(actorMean.shape);
const actions = actorMean.add(stdDev.mul(noise));

const actionData = actions.dataSync();

// Update actuator controls
for (let i = 0; i < actionData.length; i++) {
if (i < this.simulation.ctrl.length) {

// let clippedValue = Math.max(-1, Math.mi(1, data[i]));

// let [min, max] = this.actuatorRanges[i];

// // Scale to fit between min and maxn
// let newValue = min + (clippedValue + 1) * (max - min) / 2;

// // Update the actuator value
// this.simulation.ctrl[i] = newValue;

// // Optionally, update the corresponding parameter
// this.params[this.actuatorNames[i]] = newValue;
this.simulation.ctrl[i] = data[i];
this.params[this.actuatorNames[i]] = data[i];
let action = actionData[i];

// Clip action to [-1, 1] range
let clippedAction = Math.max(-1, Math.min(1, action));

// Scale action to actuator range
let [min, max] = this.actuatorRanges[i];
let newValue = min + (clippedAction + 1) * (max - min) / 2;

this.simulation.ctrl[i] = newValue;
this.params[this.actuatorNames[i]] = newValue;
} else {
console.error('Model output index out of bounds:', i);
}
}
});
}
});

// Dispose of tensors
inputTensor.dispose();
actorMean.dispose();
logStd.dispose();
criticValue.dispose();
} catch (error) {
console.error('Error during model prediction:', error);
}
}

let timestep = this.model.getOptions().timestep;
Expand Down
30 changes: 18 additions & 12 deletions examples/scenes/humanoid.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
-->

<mujoco model="Humanoid">
<!-- <option timestep="0.005"/> -->
<option timestep="0.005" solver="CG"/>
<option timestep="0.005" iterations="1" ls_iterations="4">
<flag eulerdamp="disable"/>
</option>

<visual>
<map force="0.1" zfar="30"/>
Expand All @@ -27,7 +28,7 @@

<asset>
<texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="32" height="512"/>
<texture name="body" type="cube" builtin="flat" mark="cross" width="127" height="1278" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
<texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
<material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>
<texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
Expand All @@ -38,7 +39,8 @@
<default class="body">

<!-- geoms -->
<geom type="capsule" condim="1" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="body"/>
<!-- TODO(robotics-simulation): support condim=1 for humanoid capsules. -->
<geom type="capsule" condim="3" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="body" contype="0" conaffinity="0"/>
<default class="thigh">
<geom size=".06"/>
</default>
Expand All @@ -65,9 +67,9 @@
</default>

<!-- joints -->
<joint type="hinge" damping=".4" stiffness="2" armature=".01" limited="true" solimplimit="0 .99 .01"/>
<joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
<default class="joint_big">
<joint damping="10" stiffness="20"/>
<joint damping="5" stiffness="10"/>
<default class="hip_x">
<joint range="-30 10"/>
</default>
Expand All @@ -78,7 +80,7 @@
<joint axis="0 1 0" range="-150 20"/>
</default>
<default class="joint_big_stiff">
<joint stiffness="40"/>
<joint stiffness="20"/>
</default>
</default>
<default class="knee">
Expand All @@ -87,10 +89,10 @@
<default class="ankle">
<joint range="-50 50"/>
<default class="ankle_y">
<joint pos="0 0 .08" axis="0 1 0" stiffness="12"/>
<joint pos="0 0 .08" axis="0 1 0" stiffness="6"/>
</default>
<default class="ankle_x">
<joint pos="0 0 .04" stiffness="6"/>
<joint pos="0 0 .04" stiffness="3"/>
</default>
</default>
<default class="shoulder">
Expand Down Expand Up @@ -187,9 +189,13 @@
<contact>
<exclude body1="waist_lower" body2="thigh_right"/>
<exclude body1="waist_lower" body2="thigh_left"/>
<pair geom1="foot1_left" geom2="floor"/>
<pair geom1="foot1_right" geom2="floor"/>
<pair geom1="foot2_left" geom2="floor"/>
<pair geom1="foot2_right" geom2="floor"/>
</contact>

<tendon>
<!-- <tendon>
<fixed name="hamstring_right" limited="true" range="-0.3 2">
<joint joint="hip_y_right" coef=".5"/>
<joint joint="knee_right" coef="-.5"/>
Expand All @@ -198,7 +204,7 @@
<joint joint="hip_y_left" coef=".5"/>
<joint joint="knee_left" coef="-.5"/>
</fixed>
</tendon>
</tendon> -->

<actuator>
<motor name="abdomen_y" gear="40" joint="abdomen_y"/>
Expand Down Expand Up @@ -247,4 +253,4 @@
-0.08 -0.01 -0.37 -0.685 -0.35 -0.09
0.109 -0.067 -0.7 -0.05 0.12 0.16"/>
</keyframe>
</mujoco>
</mujoco>
Loading

0 comments on commit 39af631

Please sign in to comment.