Skip to content

Commit

Permalink
support for uploading custom models
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanjzhao committed Aug 20, 2024
1 parent a99db95 commit 8e13ad4
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 15 deletions.
9 changes: 5 additions & 4 deletions examples/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,12 @@ export class MuJoCoDemo {
this.ppo_model = await tf.loadLayersModel('models/brax_humanoid_standup/model.json');
this.getObservation = () => this.getObservationSkeleton(0, 20, 12);
break;
case 'dora/dora2.xml':
this.ppo_model = await tf.loadLayersModel('models/dora/model.json');
this.getObservation = () => this.getObservationSkeleton(0, 100, 72); // 172 diff total
break;
default:
throw new Error(`Unknown model path: ${this.params.scene}`);
throw new Error(`Unknown Tensorflow.js model for XML path: ${this.params.scene}`);
}
}

Expand Down Expand Up @@ -265,9 +269,6 @@ export class MuJoCoDemo {
});
}

// console.log(this.model)
// console.log(this.model.getOptions().timestep);

let timestep = this.model.getOptions().timestep;
if (timeMS - this.mujoco_time > 35.0) { this.mujoco_time = timeMS; }
while (this.mujoco_time < timeMS) {
Expand Down
138 changes: 127 additions & 11 deletions examples/mujocoUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,136 @@ export function setupGUI(parentContext) {
parentContext.controls.target.set(0, 0.7, 0);
parentContext.controls.update(); });

// Add scene selection dropdown.
let reload = reloadFunc.bind(parentContext);
parentContext.gui.add(parentContext.params, 'scene', {
"Humanoid": "humanoid.xml",
"Brax Humanoid": "brax_humanoid.xml",

parentContext.allScenes = {
"Humanoid": "humanoid.xml",
"Brax Humanoid": "brax_humanoid.xml",
"Brax Humanoid Standup": "brax_humanoidstandup.xml",
"Dora": "dora/dora2.xml",
"Hammock": "hammock.xml",
// "Mug": "mug.xml",
// "Stompy Legs": "stompy/legs.xml",
// "Cassie": "agility_cassie/scene.xml",
// "Hammock": "hammock.xml", "Balloons": "balloons.xml", "Hand": "shadow_hand/scene_right.xml",
// "Flag": "flag.xml", "Mug": "mug.xml", "Tendon": "model_with_tendon.xml"
}).name('Example Scene').onChange(reload);
};

// Add scene selection dropdown.
let reload = reloadFunc.bind(parentContext);
let sceneDropdown = parentContext.gui.add(parentContext.params, 'scene', parentContext.allScenes).name('Example Scene').onChange(reload);

// Add upload button
let uploadButton = {
upload: function() {
let input = document.createElement('input');
input.type = 'file';
input.multiple = true;
input.accept = '.xml,.obj,.stl';
input.onchange = async function(event) {
let files = event.target.files;
let xmlFile = null;
let meshFiles = [];
let newSceneName = '';

for (let file of files) {
if (file.name.endsWith('.xml')) {
xmlFile = file;
newSceneName = file.name.split('.')[0];
} else {
meshFiles.push(file);
}
}

if (!xmlFile) {
alert('Please include an XML file.');
return;
}

// Create 'working' directory if it doesn't exist
if (!parentContext.mujoco.FS.analyzePath('/working').exists) {
parentContext.mujoco.FS.mkdir('/working');
}

// Write XML file
let xmlContent = await xmlFile.arrayBuffer();
parentContext.mujoco.FS.writeFile(`/working/${xmlFile.name}`, new Uint8Array(xmlContent));

// Write mesh files
for (let meshFile of meshFiles) {
let meshContent = await meshFile.arrayBuffer();
parentContext.mujoco.FS.writeFile(`/working/${meshFile.name}`, new Uint8Array(meshContent));
}

// Update scene dropdown
parentContext.allScenes[newSceneName] = xmlFile.name;
updateSceneDropdown(sceneDropdown, parentContext.allScenes);

parentContext.params.scene = xmlFile.name;
sceneDropdown.updateDisplay();

console.log(`Uploaded ${xmlFile.name} and ${meshFiles.length} mesh file(s)`);
// alert(`Uploaded ${xmlFile.name} and ${meshFiles.length} mesh file(s)`);


// Trigger a reload of the scene
reload();
};
input.click();
}
};

parentContext.gui.add(uploadButton, 'upload').name('Upload Scene');


function updateSceneDropdown(dropdown, scenes) {
// Store the current onChange function
let onChangeFunc = dropdown.__onChange;

// Remove all options
if (dropdown.__select && dropdown.__select.options) {
dropdown.__select.options.length = 0;
}

console.log(scenes)
dropdown.__select = document.createElement('select');

// Add new options
for (let [name, file] of Object.entries(scenes)) {
let option = document.createElement('option');
option.text = name;
option.value = file;
dropdown.__select.add(option);
}

// Restore the onChange function
dropdown.__onChange = onChangeFunc;
}

// function updateSceneDropdown(dropdown, scenes) {
// // Remove all options from the underlying select element
// if (dropdown.__select && dropdown.__select.options) {
// dropdown.__select.options.length = 0;
// }
// console.log("scenes", scenes)

// // Update the dropdown's __select property
// dropdown.__select = document.createElement('select');

// // Add new options
// for (let [name, file] of Object.entries(scenes)) {
// let option = document.createElement('option');
// option.text = name;
// option.value = file;
// dropdown.__select.add(option);
// }

// // Update the controller's object and property
// dropdown.object = scenes;
// dropdown.property = 'scene';

// // Rebuild the DOM elements
// dropdown.domElement.removeChild(dropdown.domElement.childNodes[0]);
// dropdown.domElement.appendChild(dropdown.__select);

// // Update the display
// dropdown.updateDisplay();
// dropdown.onChange(reload);
// }

// Add a help menu.
// Parameters:
Expand Down
Binary file added models/dora/group1-shard1of1.bin
Binary file not shown.
1 change: 1 addition & 0 deletions models/dora/model.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"format": "layers-model", "generatedBy": "keras v2.16.0", "convertedBy": "TensorFlow.js Converter v4.20.0", "modelTopology": {"keras_version": "2.16.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": [null, 235], "dtype": "float32", "sparse": false, "ragged": false, "name": "input_layer"}}, {"class_name": "Dense", "config": {"name": "actor_dense1", "trainable": true, "dtype": "float32", "units": 512, "activation": "elu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "actor_dense2", "trainable": true, "dtype": "float32", "units": 256, "activation": "elu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "actor_dense3", "trainable": true, "dtype": "float32", "units": 128, "activation": "elu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "actor_output", "trainable": true, "dtype": "float32", "units": 12, "activation": "linear", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}}}, "weightsManifest": [{"paths": ["group1-shard1of1.bin"], "weights": [{"name": "actor_dense1/kernel", "shape": [235, 512], "dtype": "float32"}, {"name": "actor_dense1/bias", "shape": [512], "dtype": "float32"}, {"name": "actor_dense2/kernel", "shape": [512, 256], "dtype": "float32"}, {"name": "actor_dense2/bias", "shape": [256], "dtype": "float32"}, {"name": "actor_dense3/kernel", "shape": [256, 128], "dtype": "float32"}, {"name": "actor_dense3/bias", "shape": [128], "dtype": "float32"}, {"name": "actor_output/kernel", "shape": [128, 12], "dtype": "float32"}, {"name": "actor_output/bias", "shape": [12], "dtype": "float32"}]}]}

0 comments on commit 8e13ad4

Please sign in to comment.