diff --git a/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/DirectSqlUpdatePart.java b/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/DirectSqlUpdatePart.java index 4a7f831d8d04..94926c01564b 100644 --- a/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/DirectSqlUpdatePart.java +++ b/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/DirectSqlUpdatePart.java @@ -77,6 +77,7 @@ import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.extractSqlInt; import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.extractSqlLong; import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.getModelIdentity; +import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.makeParams; import static org.apache.hadoop.hive.metastore.utils.MetaStoreServerUtils.getPartValsFromName; /** @@ -98,10 +99,6 @@ class DirectSqlUpdatePart extends DirectSqlBase { sqlGenerator = new SQLGenerator(dbType, conf); } - static String quoteString(String input) { - return "'" + input + "'"; - } - private void populateInsertUpdateMap(Map statsPartInfoMap, Map updateMap, MapinsertMap, @@ -412,35 +409,39 @@ private Map> updatePartitionParamTable(Connection db private Map getPartitionInfo(Connection dbConn, long tblId, Map partColStatsMap) - throws SQLException, MetaException { - List queries = new ArrayList<>(); - StringBuilder prefix = new StringBuilder(); - StringBuilder suffix = new StringBuilder(); + throws MetaException { Map partitionInfoMap = new HashMap<>(); + List partNames = new ArrayList<>(partColStatsMap.keySet()); + if (partNames.isEmpty()) { + return partitionInfoMap; + } - List partKeys = partColStatsMap.keySet().stream().map( - e -> quoteString(e)).collect(Collectors.toList() - ); - - prefix.append("select \"PART_ID\", \"WRITE_ID\", \"PART_NAME\" from \"PARTITIONS\" where "); - suffix.append(" and \"TBL_ID\" = " + tblId); - TxnUtils.buildQueryWithINClauseStrings(conf, queries, prefix, suffix, - partKeys, "\"PART_NAME\"", true, false); - - try (Statement statement = dbConn.createStatement()) { - for (String query : queries) { + Batchable.runBatched(maxBatchSize, partNames, new Batchable() { + @Override + public List run(List input) throws Exception { + String placeholders = makeParams(input.size()); + String query = "select \"PART_ID\", \"WRITE_ID\", \"PART_NAME\" from \"PARTITIONS\" where " + + "\"PART_NAME\" in (" + placeholders + ") and \"TBL_ID\" = ?"; // Select for update makes sure that the partitions are not modified while the stats are getting updated. query = sqlGenerator.addForUpdateClause(query); LOG.debug("Execute query: " + query); - try (ResultSet rs = statement.executeQuery(query)) { - while (rs.next()) { - PartitionInfo partitionInfo = new PartitionInfo(rs.getLong(1), - rs.getLong(2), rs.getString(3)); - partitionInfoMap.put(partitionInfo, partColStatsMap.get(rs.getString(3))); + try (PreparedStatement ps = dbConn.prepareStatement(query)) { + int paramIndex = 1; + for (String partName : input) { + ps.setString(paramIndex++, partName); + } + ps.setLong(paramIndex, tblId); + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + String partName = rs.getString(3); + PartitionInfo partitionInfo = new PartitionInfo(rs.getLong(1), rs.getLong(2), partName); + partitionInfoMap.put(partitionInfo, partColStatsMap.get(partName)); + } } } + return Collections.emptyList(); } - } + }); return partitionInfoMap; } @@ -473,6 +474,10 @@ public Map> updatePartitionColumnStatistics(Map partitionInfoMap = getPartitionInfo(dbConn, tbl.getId(), partColStatsMap); + if (partitionInfoMap.isEmpty()) { + return Collections.emptyMap(); + } + result = updatePartitionParamTable(dbConn, partitionInfoMap, validWriteIds, writeId, TxnUtils.isAcidTable(tbl)); diff --git a/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/MetaStoreDirectSql.java b/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/MetaStoreDirectSql.java index 3a317104ad8e..713f3ff4e338 100644 --- a/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/MetaStoreDirectSql.java +++ b/standalone-metastore/metastore-server/src/main/java/org/apache/hadoop/hive/metastore/directsql/MetaStoreDirectSql.java @@ -738,7 +738,7 @@ public List getPartitionsViaPartNames(final String catName, final Str return Batchable.runBatched(batchSize, partNames, new Batchable() { @Override public List run(List input) throws MetaException { - return getPartitionsByNames(catName, dbName, tblName, partNames, false, args); + return getPartitionsByNames(catName, dbName, tblName, input, false, args); } }); } @@ -1028,9 +1028,7 @@ private List getPartitionsByNames(String catName, String dbName, throws MetaException { // Get most of the fields for the partNames provided. // Assume db and table names are the same for all partition, as provided in arguments. - String quotedPartNames = partNameList.stream() - .map(DirectSqlUpdatePart::quoteString) - .collect(Collectors.joining(",")); + String partNameParams = makeParams(partNameList.size()); String queryText = "select " + PARTITIONS + ".\"PART_ID\"," + SDS + ".\"SD_ID\"," + SDS + ".\"CD_ID\"," @@ -1043,11 +1041,18 @@ private List getPartitionsByNames(String catName, String dbName, + " left outer join " + SERDES + " on " + SDS + ".\"SERDE_ID\" = " + SERDES + ".\"SERDE_ID\" " + " inner join " + TBLS + " on " + TBLS + ".\"TBL_ID\" = " + PARTITIONS + ".\"TBL_ID\" " + " inner join " + DBS + " on " + DBS + ".\"DB_ID\" = " + TBLS + ".\"DB_ID\" " - + " where \"PART_NAME\" in (" + quotedPartNames + ") " + + " where " + PARTITIONS + ".\"PART_NAME\" in (" + partNameParams + ") " + " and " + TBLS + ".\"TBL_NAME\" = ? and " + DBS + ".\"NAME\" = ? and " + DBS + ".\"CTLG_NAME\" = ? order by \"PART_NAME\" asc"; - Object[] params = new Object[]{tblName, dbName, catName}; + Object[] params = new Object[partNameList.size() + 3]; + int i = 0; + for (String partName : partNameList) { + params[i++] = partName; + } + params[i++] = tblName; + params[i++] = dbName; + params[i] = catName; return getPartitionsByQuery(catName, dbName, tblName, queryText, params, isAcidTable, args); } diff --git a/standalone-metastore/metastore-server/src/test/java/org/apache/hadoop/hive/metastore/TestObjectStore.java b/standalone-metastore/metastore-server/src/test/java/org/apache/hadoop/hive/metastore/TestObjectStore.java index afede2f768c0..bbbf5f4bb326 100644 --- a/standalone-metastore/metastore-server/src/test/java/org/apache/hadoop/hive/metastore/TestObjectStore.java +++ b/standalone-metastore/metastore-server/src/test/java/org/apache/hadoop/hive/metastore/TestObjectStore.java @@ -149,6 +149,9 @@ public class TestObjectStore { private static final String USER1 = "testobjectstoreuser1"; private static final String ROLE1 = "testobjectstorerole1"; private static final String ROLE2 = "testobjectstorerole2"; + private static final String SQLI_PART_NAME = "test_part_col=missing') OR 1=1 -- "; + private static final List ALL_PART_NAMES = + Arrays.asList("test_part_col=a0", "test_part_col=a1", "test_part_col=a2"); private static final Logger LOG = LoggerFactory.getLogger(TestObjectStore.class.getName()); private static final class LongSupplier implements Supplier { @@ -802,6 +805,21 @@ public void testDirectSQLDropPartitionsCacheInSession() Assert.assertEquals(1, partitions.size()); } + @Test + public void testDirectSQLDropPartitionsRejectsSqlInjectionInPartName() + throws Exception { + createPartitionedTable(false, false, new HashSet<>()); + + objectStore.dropPartitionsInternal(DEFAULT_CATALOG_NAME, DB1, TABLE1, + Collections.singletonList(SQLI_PART_NAME), true, false); + + List partitions; + try (AutoCloseable c = deadline()) { + partitions = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, ALL_PART_NAMES); + } + Assert.assertEquals(3, partitions.size()); + } + /** * Checks if the JDO cache is able to handle directSQL partition drops cross sessions. */ @@ -1024,6 +1042,63 @@ public void testDeletePartitionColumnStatisticsWhenEngineHasSpecialCharacter() t List.of("test_part_col=a2"), null, "special '"); } + @Test + public void testGetPartitionsByNamesRejectsSqlInjectionInPartName() throws Exception { + createPartitionedTable(true, true, new HashSet<>()); + List partitions; + try (AutoCloseable c = deadline()) { + partitions = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, + Collections.singletonList(SQLI_PART_NAME)); + } + Assert.assertEquals(0, partitions.size()); + try (AutoCloseable c = deadline()) { + partitions = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, ALL_PART_NAMES); + } + Assert.assertEquals(3, partitions.size()); + } + + @Test + public void testUpdatePartitionColumnStatisticsInBatchRejectsSqlInjectionInPartName() + throws Exception { + createPartitionedTable(true, true, new HashSet<>()); + Table tbl = objectStore.getTable(DEFAULT_CATALOG_NAME, DB1, TABLE1); + + List> baseline; + try (AutoCloseable c = deadline()) { + baseline = objectStore.getPartitionColumnStatistics(DEFAULT_CATALOG_NAME, DB1, TABLE1, + ALL_PART_NAMES, Collections.singletonList("test_part_col")); + } + Assert.assertEquals(1, baseline.size()); + Assert.assertEquals(3, baseline.get(0).size()); + long baselineNumNulls = baseline.get(0).get(0).getStatsObj().get(0).getStatsData() + .getLongStats().getNumNulls(); + + ColumnStatisticsDesc statsDesc = new ColumnStatisticsDesc(false, DB1, TABLE1); + statsDesc.setCatName(DEFAULT_CATALOG_NAME); + statsDesc.setPartName(SQLI_PART_NAME); + ColumnStatisticsData injectedData = new ColStatsBuilder<>(long.class).numNulls(999).numDVs(2) + .low(3L).high(4L).build(); + ColumnStatisticsObj statsObj = new ColumnStatisticsObj("test_part_col", "int", injectedData); + ColumnStatistics maliciousStats = new ColumnStatistics(statsDesc, + Collections.singletonList(statsObj)); + maliciousStats.setEngine(ENGINE); + + Map statsMap = new HashMap<>(); + statsMap.put(SQLI_PART_NAME, maliciousStats); + objectStore.updatePartitionColumnStatisticsInBatch(statsMap, tbl, null, null, -1); + + List> after; + try (AutoCloseable c = deadline()) { + after = objectStore.getPartitionColumnStatistics(DEFAULT_CATALOG_NAME, DB1, TABLE1, + ALL_PART_NAMES, Collections.singletonList("test_part_col")); + } + Assert.assertEquals(3, after.get(0).size()); + for (ColumnStatistics cs : after.get(0)) { + Assert.assertEquals(baselineNumNulls, + cs.getStatsObj().get(0).getStatsData().getLongStats().getNumNulls()); + } + } + private void setAggrConf(boolean enableBitVector, boolean enableKll, int batchSize) { Configuration conf2 = MetastoreConf.newMetastoreConf(conf); MetastoreConf.setBoolVar(conf2, ConfVars.STATS_FETCH_BITVECTOR, enableBitVector);