Skip to content

Commit

Permalink
Issue #2175 Allow unmarshalling of objects of permitted classes only …
Browse files Browse the repository at this point in the history
…for StateRepository
  • Loading branch information
rahul-mittal committed Dec 18, 2017
1 parent 98d91f1 commit 557c5af
Show file tree
Hide file tree
Showing 15 changed files with 541 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,32 @@
package org.ehcache.spi.persistence;

import java.io.Serializable;
import java.util.function.Predicate;

/**
* A repository allowing to preserve state in the context of a {@link org.ehcache.Cache}.
*/
public interface StateRepository {

/**
* Gets a named state holder rooted in the current {@code StateRepository}.
* <p>
* If the state holder existed already, it is returned with its content fully available.
*
* @deprecated Replaced by {@link #getPersistentStateHolder(String, Class, Class, Predicate, ClassLoader)} that takes in a Predicate that authorizes a class for deserialization
*
* @param name the state holder name
* @param keyClass concrete key type
* @param valueClass concrete value type
* @param <K> the key type, must be {@code Serializable}
* @param <V> the value type, must be {@code Serializable}
* @return a state holder
*/
@Deprecated
default <K extends Serializable, V extends Serializable> StateHolder<K, V> getPersistentStateHolder(String name, Class<K> keyClass, Class<V> valueClass) {
return getPersistentStateHolder(name, keyClass, valueClass, c -> true, null);
}

/**
* Gets a named state holder rooted in the current {@code StateRepository}.
* <p>
Expand All @@ -33,7 +53,13 @@ public interface StateRepository {
* @param valueClass concrete value type
* @param <K> the key type, must be {@code Serializable}
* @param <V> the value type, must be {@code Serializable}
* @param isClassPermitted Predicate that determines whether a class is authorized for deserialization as part of key or value deserialization
* @param classLoader class loader used at the time of deserialization of key and value
* @return a state holder
*/
<K extends Serializable, V extends Serializable> StateHolder<K, V> getPersistentStateHolder(String name, Class<K> keyClass, Class<V> valueClass);
<K extends Serializable, V extends Serializable> StateHolder<K, V> getPersistentStateHolder(String name,
Class<K> keyClass,
Class<V> valueClass,
Predicate<Class<?>> isClassPermitted,
ClassLoader classLoader);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.ehcache.spi.persistence.StateRepository;

import java.io.Serializable;
import java.util.function.Predicate;

/**
* ClusterStateRepository
Expand All @@ -39,7 +40,12 @@ class ClusterStateRepository implements StateRepository {
}

@Override
public <K extends Serializable, V extends Serializable> StateHolder<K, V> getPersistentStateHolder(String name, Class<K> keyClass, Class<V> valueClass) {
return new ClusteredStateHolder<>(clusterCacheIdentifier.getId(), composedId + "-" + name, clientEntity, keyClass, valueClass);
public <K extends Serializable, V extends Serializable> StateHolder<K, V> getPersistentStateHolder(String name,
Class<K> keyClass,
Class<V> valueClass,
Predicate<Class<?>> isClassPermitted,
ClassLoader classLoader) {
return new ClusteredStateHolder<>(clusterCacheIdentifier.getId(), composedId + "-" + name, clientEntity,
keyClass, valueClass, isClassPermitted, classLoader);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;

import static org.ehcache.clustered.client.internal.service.ValueCodecFactory.getCodecForClass;

Expand All @@ -39,10 +40,12 @@ public class ClusteredStateHolder<K, V> implements StateHolder<K, V> {
private final ValueCodec<K> keyCodec;
private final ValueCodec<V> valueCodec;

public ClusteredStateHolder(final String cacheId, final String mapId, final ClusterTierClientEntity entity, Class<K> keyClass, Class<V> valueClass) {
public ClusteredStateHolder(final String cacheId, final String mapId, final ClusterTierClientEntity entity,
Class<K> keyClass, Class<V> valueClass,
Predicate<Class<?>> isClassPermittted, ClassLoader classLoader) {
this.keyClass = keyClass;
this.keyCodec = getCodecForClass(keyClass);
this.valueCodec = getCodecForClass(valueClass);
this.keyCodec = getCodecForClass(keyClass, isClassPermittted, classLoader);
this.valueCodec = getCodecForClass(valueClass, isClassPermittted, classLoader);
this.messageFactory = new StateRepositoryMessageFactory(cacheId, mapId);
this.entity = entity;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.ehcache.clustered.client.internal.service;

import org.ehcache.clustered.common.internal.store.FilteredObjectInputStream;
import org.ehcache.clustered.common.internal.store.ValueWrapper;

import java.io.ByteArrayInputStream;
Expand All @@ -24,12 +25,13 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.function.Predicate;

/**
* ValueCodecFactory
*/
class ValueCodecFactory {
static <T> ValueCodec<T> getCodecForClass(Class<T> clazz) {
static <T> ValueCodec<T> getCodecForClass(Class<T> clazz, Predicate<Class<?>> isClassPermitted, ClassLoader classLoader) {
if (!Serializable.class.isAssignableFrom(clazz)) {
throw new IllegalArgumentException("The provided type is invalid as it is not Serializable " + clazz);
}
Expand All @@ -39,7 +41,7 @@ static <T> ValueCodec<T> getCodecForClass(Class<T> clazz) {
|| clazz.isPrimitive() || String.class.equals(clazz)) {
return new IdentityCodec<>();
} else {
return new SerializationWrapperCodec<>();
return new SerializationWrapperCodec<>(isClassPermitted, classLoader);
}
}

Expand All @@ -57,6 +59,15 @@ public T decode(Object input) {
}

private static class SerializationWrapperCodec<T> implements ValueCodec<T> {

private final Predicate<Class<?>> isClassPermitted;
private final ClassLoader classLoader;

public SerializationWrapperCodec(Predicate<Class<?>> isClassPermitted, ClassLoader classLoader) {
this.isClassPermitted = isClassPermitted;
this.classLoader = classLoader;
}

@Override
public Object encode(T input) {
if (input == null) {
Expand All @@ -78,18 +89,13 @@ public T decode(Object input) {
}
ValueWrapper data = (ValueWrapper) input;
ByteArrayInputStream bais = new ByteArrayInputStream(data.getValue());
try {
try (ObjectInputStream ois = new ObjectInputStream(bais)) {
@SuppressWarnings("unchecked")
T result = (T) ois.readObject();
return result;
} catch (ClassNotFoundException e) {
throw new RuntimeException("Could not load class", e);
}
} catch(IOException e) {
// ignore
try (ObjectInputStream ois = new FilteredObjectInputStream(bais, isClassPermitted, classLoader)) {
@SuppressWarnings("unchecked")
T result = (T) ois.readObject();
return result;
} catch (ClassNotFoundException | IOException e) {
throw new RuntimeException("Could not load class", e);
}
throw new AssertionError("Cannot reach here!");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
/*
* Copyright Terracotta, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.ehcache.clustered.client.internal.service;

import org.ehcache.clustered.client.config.ClusteringServiceConfiguration;
import org.ehcache.clustered.client.config.builders.ClusteringServiceConfigurationBuilder;
import org.ehcache.clustered.client.internal.ClusterTierManagerClientEntityService;
import org.ehcache.clustered.client.internal.UnitTestConnectionService;
import org.ehcache.clustered.client.internal.lock.VoltronReadWriteLockEntityClientService;
import org.ehcache.clustered.client.internal.store.ClusterTierClientEntityService;
import org.ehcache.clustered.client.internal.store.ServerStoreProxy;
import org.ehcache.clustered.client.internal.store.SimpleClusterTierClientEntity;
import org.ehcache.clustered.client.service.ClusteringService;
import org.ehcache.clustered.common.Consistency;
import org.ehcache.clustered.lock.server.VoltronReadWriteLockServerEntityService;
import org.ehcache.clustered.server.ClusterTierManagerServerEntityService;
import org.ehcache.clustered.server.store.ClusterTierServerEntityService;
import org.ehcache.core.config.BaseCacheConfiguration;
import org.ehcache.core.internal.store.StoreConfigurationImpl;
import org.ehcache.spi.persistence.StateHolder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.terracotta.offheapresource.OffHeapResourcesProvider;
import org.terracotta.offheapresource.config.MemoryUnit;
import org.terracotta.passthrough.PassthroughClusterControl;
import org.terracotta.passthrough.PassthroughTestHelpers;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.net.URI;
import java.util.Arrays;

import static org.ehcache.clustered.client.config.builders.ClusteredResourcePoolBuilder.clusteredDedicated;
import static org.ehcache.clustered.client.internal.UnitTestConnectionService.getOffheapResourcesType;
import static org.ehcache.config.Eviction.noAdvice;
import static org.ehcache.config.builders.ResourcePoolsBuilder.newResourcePoolsBuilder;
import static org.ehcache.expiry.Expirations.noExpiration;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

public class StateRepositoryWhitelistingTest {

private PassthroughClusterControl clusterControl;
private static String STRIPENAME = "stripe";
private static String STRIPE_URI = "passthrough://" + STRIPENAME;
private ClusteringService service;
ClusterStateRepository stateRepository;

@Before
public void setUp() throws Exception {
this.clusterControl = PassthroughTestHelpers.createActiveOnly(STRIPENAME,
server -> {
server.registerServerEntityService(new ClusterTierManagerServerEntityService());
server.registerClientEntityService(new ClusterTierManagerClientEntityService());
server.registerServerEntityService(new ClusterTierServerEntityService());
server.registerClientEntityService(new ClusterTierClientEntityService());
server.registerServerEntityService(new VoltronReadWriteLockServerEntityService());
server.registerClientEntityService(new VoltronReadWriteLockEntityClientService());
server.registerExtendedConfiguration(new OffHeapResourcesProvider(getOffheapResourcesType("test", 32, MemoryUnit.MB)));

UnitTestConnectionService.addServerToStripe(STRIPENAME, server);
}
);

clusterControl.waitForActive();

ClusteringServiceConfiguration configuration =
ClusteringServiceConfigurationBuilder.cluster(URI.create(STRIPE_URI))
.autoCreate()
.build();

service = new ClusteringServiceFactory().create(configuration);

service.start(null);

BaseCacheConfiguration<Long, String> config = new BaseCacheConfiguration<>(Long.class, String.class, noAdvice(), null, noExpiration(),
newResourcePoolsBuilder().with(clusteredDedicated("test", 2, org.ehcache.config.units.MemoryUnit.MB)).build());
ClusteringService.ClusteredCacheIdentifier spaceIdentifier = (ClusteringService.ClusteredCacheIdentifier) service.getPersistenceSpaceIdentifier("test",
config);

ServerStoreProxy serverStoreProxy = service.getServerStoreProxy(spaceIdentifier,
new StoreConfigurationImpl<>(config, 1, null, null),
Consistency.STRONG,
mock(ServerStoreProxy.ServerCallback.class));

SimpleClusterTierClientEntity clientEntity = getEntity(serverStoreProxy);

stateRepository = new ClusterStateRepository(new ClusteringService.ClusteredCacheIdentifier() {
@Override
public String getId() {
return "testStateRepo";
}

@Override
public Class<ClusteringService> getServiceType() {
return ClusteringService.class;
}
}, "test", clientEntity);
}

@After
public void tearDown() throws Exception {
service.stop();
UnitTestConnectionService.removeStripe(STRIPENAME);
clusterControl.tearDown();
}

private static SimpleClusterTierClientEntity getEntity(ServerStoreProxy clusteringService) throws NoSuchFieldException, IllegalAccessException {
Field entity = clusteringService.getClass().getDeclaredField("entity");
entity.setAccessible(true);
return (SimpleClusterTierClientEntity)entity.get(clusteringService);
}

@Test
public void testWhiteListedClass() throws Exception {
StateHolder<Child, Child> testMap = stateRepository.getPersistentStateHolder("testMap", Child.class, Child.class,
Arrays.asList(Child.class, Parent.class)::contains, null);

testMap.putIfAbsent(new Child(10, 20L), new Child(20, 30L));


assertThat(testMap.get(new Child(10, 20L)), is(new Child(20, 30L)));

assertThat(testMap.entrySet(), hasSize(1));
}

@Test
public void testWhiteListedMissingClass() throws Exception {
StateHolder<Child, Child> testMap = stateRepository.getPersistentStateHolder("testMap", Child.class, Child.class,
Arrays.asList(Child.class)::contains, null);

testMap.putIfAbsent(new Child(10, 20L), new Child(20, 30L));

try {
assertThat(testMap.entrySet(), hasSize(1));
} catch (RuntimeException e) {
assertTrue(e.getMessage().equals("Could not load class"));
}
}

@Test
public void testWhitelistingForPrimitiveClass() throws Exception {
// No whitelisting for primitive classes are required as we do not deserialize them at client side
StateHolder<Integer, Integer> testMap = stateRepository.getPersistentStateHolder("testMap", Integer.class, Integer.class,
Arrays.asList(Child.class)::contains, null);

testMap.putIfAbsent(new Integer(10), new Integer(20));

assertThat(testMap.get(new Integer(10)), is(new Integer(20)));
assertThat(testMap.entrySet(), hasSize(1));
}

private static class Parent implements Serializable {
final int val;

private Parent(int val) {
this.val = val;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Parent)) return false;

Parent testVal = (Parent) o;

return val == testVal.val;
}

@Override
public int hashCode() {
return val;
}
}

private static class Child extends Parent implements Serializable {
final long longValue;

private Child(int val, long longValue) {
super(val);
this.longValue = longValue;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Child)) return false;

Child testVal = (Child) o;

return super.equals(testVal) && longValue == testVal.longValue;
}

@Override
public int hashCode() {
return super.hashCode() + 31 * Long.hashCode(longValue);
}
}
}
Loading

0 comments on commit 557c5af

Please sign in to comment.