Skip to content
This repository was archived by the owner on Apr 15, 2018. It is now read-only.

Require Java Serializable types to be whitelisted #21

Merged
merged 3 commits into from
Feb 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package io.atomix.catalyst.serializer;

import io.atomix.catalyst.serializer.util.CatalystSerializableSerializer;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
Expand Down
204 changes: 98 additions & 106 deletions serializer/src/main/java/io/atomix/catalyst/serializer/Serializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package io.atomix.catalyst.serializer;

import io.atomix.catalyst.buffer.*;
import io.atomix.catalyst.serializer.util.PooledTypeSerializer;
import io.atomix.catalyst.util.ReferenceCounted;

import java.io.*;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*;

/**
Expand All @@ -31,7 +33,7 @@
* <p>
* Serializable objects must either provide a {@link TypeSerializer}. implement {@link CatalystSerializable}, or implement
* {@link java.io.Externalizable}. For efficiency, serializable objects may implement {@link ReferenceCounted}
* or provide a {@link PooledSerializer} that reuses objects during deserialization.
* or provide a {@link PooledTypeSerializer} that reuses objects during deserialization.
* Catalyst will automatically deserialize {@link ReferenceCounted} types using an object pool.
* <p>
* Serialization via this class is not thread safe.
Expand All @@ -49,11 +51,11 @@ public class Serializer implements Cloneable {
private static final byte TYPE_ID_24 = 0x04;
private static final byte TYPE_ID_32 = 0x05;
private static final byte TYPE_CLASS = 0x07;
private static final byte TYPE_SERIALIZABLE = 0x08;
private SerializerRegistry registry;
SerializerRegistry registry;
private Map<Class<?>, TypeSerializer<?>> serializers = new HashMap<>();
private Map<String, Class<?>> types = new HashMap<>();
private final BufferAllocator allocator;
private boolean whitelistRequired = true;

/**
* Creates a new serializer instance with a default {@link UnpooledHeapAllocator}.
Expand Down Expand Up @@ -182,6 +184,43 @@ public Serializer(BufferAllocator allocator, Collection<SerializableTypeResolver
registry = new SerializerRegistry(resolvers);
}

/**
* Enables whitelisting for serializable types.
* <p>
* When whitelisting is enabled, only types that are registered with the {@link SerializerRegistry}
* can be serialized and deserialized, and classes will never be loaded by class names. This prevents
* certain types of attacks in untrusted networks.
*
* @return The serializer.
*/
public Serializer enableWhitelist() {
this.whitelistRequired = true;
return this;
}

/**
* Disables whitelisting for serializable types.
* <p>
* When whitelisting is disabled, types that are not registered may be serialized and deserialized
* by this serializer. This can pose a security risk in an untrusted network. It's recommended that
* users enable whitelisting and register serializable classes.
*
* @return The serializer.
*/
public Serializer disableWhitelist() {
this.whitelistRequired = false;
return this;
}

/**
* Indicates whether whitelisting is enabled for the serializer.
*
* @return Whether whitelisting is enabled for the serializer.
*/
public boolean isWhitelistRequired() {
return whitelistRequired;
}

/**
* Resolves serializable types with the given resolver.
* <p>
Expand Down Expand Up @@ -419,7 +458,7 @@ public <T> T copy(T object) {
private <T> TypeSerializer<T> getSerializer(Class<T> type) {
TypeSerializer<T> serializer = (TypeSerializer<T>) serializers.get(type);
if (serializer == null) {
TypeSerializerFactory factory = registry.lookup(type);
TypeSerializerFactory factory = registry.factory(type);
if (factory != null) {
serializer = (TypeSerializer<T>) factory.createSerializer(type);
serializers.put(type, serializer);
Expand Down Expand Up @@ -563,38 +602,33 @@ public <T> BufferOutput<?> writeObject(T object, BufferOutput<?> buffer) {
if (type.getEnclosingClass() != null && type.getEnclosingClass().isEnum())
type = type.getEnclosingClass();

Integer typeId = registry.ids().get(type);
if (typeId != null) {
TypeSerializer<?> serializer = getSerializer(type);
// Look up the serializer for the given object type.
TypeSerializer<?> serializer = getSerializer(type);

if (serializer == null) {
if (object instanceof Serializable) {
return writeSerializable(object, buffer);
}
throw new SerializationException("cannot serialize unregistered type: " + type);
}
// If no serializer was found, throw a serialization exception.
if (serializer == null) {
throw new SerializationException("cannot serialize unregistered type: " + type);
}

if (typeId >= 0) {
if (typeId <= MAX_ID_8) {
return writeById8(typeId, object, buffer, serializer);
} else if (typeId <= MAX_ID_16) {
return writeById16(typeId, object, buffer, serializer);
} else if (typeId <= MAX_ID_24) {
return writeById24(typeId, object, buffer, serializer);
}
}
return writeById32(typeId, object, buffer, serializer);
} else {
TypeSerializer<?> serializer = getSerializer(type);

if (serializer == null) {
if (object instanceof Serializable) {
return writeSerializable(object, buffer);
}
throw new SerializationException("cannot serialize unregistered type: " + type);
}
// Lookup the serializable type ID for the type.
int typeId = registry.id(type);

// If no type ID was registered, write the object with the class name.
if (typeId == 0) {
return writeByClass(type, object, buffer, serializer);
}

// Write the serializable type ID in the most compact form possible.
if (typeId >= 0) {
if (typeId <= MAX_ID_8) {
return writeById8(typeId, object, buffer, serializer);
} else if (typeId <= MAX_ID_16) {
return writeById16(typeId, object, buffer, serializer);
} else if (typeId <= MAX_ID_24) {
return writeById24(typeId, object, buffer, serializer);
}
}
return writeById32(typeId, object, buffer, serializer);
}

/**
Expand Down Expand Up @@ -692,33 +726,14 @@ private <T> BufferOutput<?> writeById32(int id, T object, BufferOutput<?> buffer
* @param <T> The object type.
* @return The written buffer.
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings("unchecked")
private <T> BufferOutput<?> writeByClass(Class<?> type, T object, BufferOutput<?> buffer, TypeSerializer serializer) {
if (whitelistRequired)
throw new SerializationException("cannot serialize unregistered type: " + type);
serializer.write(object, buffer.writeByte(TYPE_CLASS).writeUTF8(type.getName()), this);
return buffer;
}

/**
* Writes a serializable object to the given buffer.
*
* @param serializable The object to write to the buffer.
* @param buffer The buffer to which to write the object.
* @param <T> The object type.
* @return The written buffer.
*/
private <T> BufferOutput<?> writeSerializable(T serializable, BufferOutput<?> buffer) {
buffer.writeByte(TYPE_SERIALIZABLE);
try (ByteArrayOutputStream os = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(os)) {
out.writeObject(serializable);
out.flush();
byte[] bytes = os.toByteArray();
buffer.writeUnsignedShort(bytes.length).write(bytes);
} catch (IOException e) {
throw new SerializationException("failed to serialize Java object", e);
}
return buffer;
}

/**
* Reads an object from the given input stream.
* <p>
Expand Down Expand Up @@ -839,8 +854,6 @@ public <T> T readObject(BufferInput<?> buffer) {
return readById32(buffer);
case TYPE_CLASS:
return readByClass(buffer);
case TYPE_SERIALIZABLE:
return readSerializable(buffer);
default:
throw new SerializationException("unknown serializable type");
}
Expand All @@ -865,14 +878,8 @@ private Buffer readBuffer(BufferInput<?> buffer) {
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readById8(BufferInput<?> buffer) {
int id = buffer.readUnsignedByte();
Class<T> type = (Class<T>) registry.types().get(id);
TypeSerializer<T> serializer = getSerializer(type);
if (type == null || serializer == null)
throw new SerializationException("cannot deserialize: unknown type");
return serializer.read(type, buffer, this);
return readById(buffer.readUnsignedByte(), buffer);
}

/**
Expand All @@ -882,47 +889,50 @@ private <T> T readById8(BufferInput<?> buffer) {
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readById16(BufferInput<?> buffer) {
int id = buffer.readUnsignedShort();
Class<T> type = (Class<T>) registry.types().get(id);
TypeSerializer<T> serializer = getSerializer(type);
if (type == null || serializer == null)
throw new SerializationException("cannot deserialize: unknown type");
return serializer.read(type, buffer, this);
return readById(buffer.readUnsignedShort(), buffer);
}


/**
* Reads a serializable object.
*
* @param buffer The buffer from which to read the object.
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readById24(BufferInput<?> buffer) {
int id = buffer.readUnsignedMedium();
Class<T> type = (Class<T>) registry.types().get(id);
TypeSerializer<T> serializer = getSerializer(type);
if (type == null || serializer == null)
throw new SerializationException("cannot deserialize: unknown type");
return serializer.read(type, buffer, this);
return readById(buffer.readUnsignedMedium(), buffer);
}

/**
* Reads a serializable object.
*
* @param buffer The buffer from which to read the object.
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readById32(BufferInput<?> buffer) {
int id = buffer.readInt();
Class<T> type = (Class<T>) registry.types().get(id);
return readById(buffer.readInt(), buffer);
}

/**
* Reads a serializable object.
*
* @param id The serializable type ID.
* @param buffer The buffer from which to read the object.
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readById(int id, BufferInput<?> buffer) {
Class<T> type = (Class<T>) registry.type(id);
if (type == null)
throw new SerializationException("cannot deserialize: unknown type");

TypeSerializer<T> serializer = getSerializer(type);
if (type == null || serializer == null)
if (serializer == null)
throw new SerializationException("cannot deserialize: unknown type");

return serializer.read(type, buffer, this);
}

Expand All @@ -936,6 +946,9 @@ private <T> T readById32(BufferInput<?> buffer) {
@SuppressWarnings("unchecked")
private <T> T readByClass(BufferInput<?> buffer) {
String name = buffer.readUTF8();
if (whitelistRequired)
throw new SerializationException("cannot deserialize unregistered type: " + name);

Class<T> type = (Class<T>) types.get(name);
if (type == null) {
try {
Expand All @@ -950,30 +963,9 @@ private <T> T readByClass(BufferInput<?> buffer) {

TypeSerializer<T> serializer = getSerializer(type);
if (serializer == null)
throw new SerializationException("cannot deserialize: unknown type");
return serializer.read(type, buffer, this);
}
throw new SerializationException("cannot deserialize unregistered type: " + name);

/**
* Reads a Java serializable object.
*
* @param buffer The buffer from which to read the object.
* @param <T> The object type.
* @return The read object.
*/
@SuppressWarnings("unchecked")
private <T> T readSerializable(BufferInput<?> buffer) {
byte[] bytes = new byte[buffer.readUnsignedShort()];
buffer.read(bytes);
try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
try {
return (T) in.readObject();
} catch (ClassNotFoundException e) {
throw new SerializationException("failed to deserialize Java object", e);
}
} catch (IOException e) {
throw new SerializationException("failed to deserialize Java object", e);
}
return serializer.read(type, buffer, this);
}

@Override
Expand Down
Loading