diff --git a/src/main/java/org/apache/commons/lang3/Range.java b/src/main/java/org/apache/commons/lang3/Range.java index 2c1743956d7..e0fca54d101 100644 --- a/src/main/java/org/apache/commons/lang3/Range.java +++ b/src/main/java/org/apache/commons/lang3/Range.java @@ -550,6 +550,15 @@ private void readObject(final ObjectInputStream in) throws IOException, ClassNot if (hashCode != hash(minimum, maximum)) { throw new InvalidObjectException("Range hashCode does not match minimum/maximum."); } + if (maximum == null) { + throw new InvalidObjectException("maximum null"); + } + if (minimum == null) { + throw new InvalidObjectException("minimum null"); + } + if (comparator == null) { + throw new InvalidObjectException("comparator null"); + } } /** diff --git a/src/test/java/org/apache/commons/lang3/RangeReadObjectTest.java b/src/test/java/org/apache/commons/lang3/RangeReadObjectTest.java index 866270b596f..be2866ff253 100644 --- a/src/test/java/org/apache/commons/lang3/RangeReadObjectTest.java +++ b/src/test/java/org/apache/commons/lang3/RangeReadObjectTest.java @@ -21,7 +21,14 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.InvalidObjectException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Objects; import org.apache.commons.lang3.reflect.FieldUtils; import org.junit.jupiter.api.Test; @@ -31,6 +38,59 @@ */ class RangeReadObjectTest { + /** + * Standin class used only to drive {@link ObjectOutputStream#writeObject(Object)} into emitting a stream that matches the wire format of {@link Range} but + * with caller-controlled field values. The class name and serialVersionUID are spoofed in the stream below via a custom {@link ObjectOutputStream} subclass + * so the stream reads back as a {@code Range}. + */ + private static final class RangeForge implements Serializable { + + private static final long serialVersionUID = 2L; // matches Range.serialVersionUID + private final Object comparator; + private final int hashCode; + private final Object maximum; + private final Object minimum; + + RangeForge(final Object comparator, final Object minimum, final Object maximum, final int hashCode) { + this.comparator = comparator; + this.minimum = minimum; + this.maximum = maximum; + this.hashCode = hashCode; + } + } + + private static Object deserialize(final byte[] bytes) throws IOException, ClassNotFoundException { + try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) { + return ois.readObject(); + } + } + + /** + * Serializes a {@link RangeForge} but rewrites the class descriptor name to "org.apache.commons.lang3.Range" so the resulting bytes deserialize as a + * {@link Range}. Because the field set, types, order, and serialVersionUID all match, default deserialization assigns each forged value to the + * corresponding Range field via reflection (bypassing the constructor). + */ + private static byte[] forgeRangeStream(final Object comparator, final Object minimum, final Object maximum, final int hashCode) throws IOException { + // Build the legitimate-shape bytes via RangeForge, then rewrite the embedded class name. + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos) { + + @Override + protected void writeClassDescriptor(final java.io.ObjectStreamClass desc) throws IOException { + if (desc.getName().equals(RangeForge.class.getName())) { + // Emit a descriptor whose name is Range but whose field layout still matches RangeForge. + final java.io.ObjectStreamClass spoofed = java.io.ObjectStreamClass.lookup(Range.class); + super.writeClassDescriptor(spoofed); + } else { + super.writeClassDescriptor(desc); + } + } + }) { + oos.writeObject(new RangeForge(comparator, minimum, maximum, hashCode)); + } + return baos.toByteArray(); + } + @Test void testBadHashCodeRejected() throws Exception { final Range range = Range.of(1, 100); @@ -44,6 +104,47 @@ void testBadHashCodeRejected() throws Exception { assertEquals("java.io.InvalidObjectException: Range hashCode does not match minimum/maximum.", ex.getMessage()); } + /** + * Forged stream with {@code comparator == null}; F-004 hashCode check passes because we set hashCode canonically; {@code contains()} then NPEs on + * {@code comparator.compare(...)}. + */ + @Test + void testComparatorNullViaForgedStream() throws Exception { + final Integer min = Integer.valueOf(1); + final Integer max = Integer.valueOf(10); + final int canonicalHash = Objects.hash(min, max); + final byte[] forged = forgeRangeStream(null, min, max, canonicalHash); + assertThrows(InvalidObjectException.class, () -> deserialize(forged)); + } + + /** + * Forged stream with {@code maximum == null}; symmetric to F-061b. + */ + @Test + void testMaximumNullViaForgedStream() throws Exception { + final Integer min = Integer.valueOf(1); + final int canonicalHash = Objects.hash(min, (Object) null); + final Object comparator = Range.of(Integer.valueOf(1), Integer.valueOf(2)).getComparator(); + final byte[] forged = forgeRangeStream(comparator, min, null, canonicalHash); + assertThrows(InvalidObjectException.class, () -> deserialize(forged)); + } + + /** + * Forged stream with {@code minimum == null}; {@code Objects.hash(null, max)} is a valid int, so the F-004 check passes. {@code contains()} NPEs + * because {@code comparator.compare(element, null)} unboxes null (or, for ComparableComparator, calls {@code element.compareTo(null)} which is an + * NPE-by-contract). + */ + @Test + void testMinimumNullViaForgedStream() throws Exception { + final Integer max = Integer.valueOf(10); + final int canonicalHash = Objects.hash((Object) null, max); + // comparator must be non-null here so we isolate the minimum-null gap. + // We use ComparableComparator.INSTANCE via deserialization round-trip of a real Range. + final Object comparator = Range.of(Integer.valueOf(1), Integer.valueOf(2)).getComparator(); + final byte[] forged = forgeRangeStream(comparator, null, max, canonicalHash); + assertThrows(InvalidObjectException.class, () -> deserialize(forged)); + } + @Test void testRoundTripPreservesCorrectHashCode() throws Exception { final Range range = Range.of("apple", "mango");