Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.extractSqlInt;
import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.extractSqlLong;
import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.getModelIdentity;
import static org.apache.hadoop.hive.metastore.directsql.MetastoreDirectSqlUtils.makeParams;
import static org.apache.hadoop.hive.metastore.utils.MetaStoreServerUtils.getPartValsFromName;

/**
Expand All @@ -98,10 +99,6 @@ class DirectSqlUpdatePart extends DirectSqlBase {
sqlGenerator = new SQLGenerator(dbType, conf);
}

/**
 * Renders {@code input} as a single-quoted SQL string literal.
 *
 * <p>Embedded single quotes are doubled (the standard SQL escape) so that a value
 * such as {@code x') OR 1=1 -- } stays inside the literal instead of terminating it.
 * A {@code null} input is rendered as {@code 'null'}, matching what plain string
 * concatenation produced before.
 *
 * <p>NOTE(review): this is a mitigation only. Dialects that also treat backslash as
 * an escape character inside string literals are not covered by quote doubling;
 * bind parameters ({@code ?} placeholders) remain the safe way to pass values.
 *
 * @param input the raw value to embed in a SQL statement; may be {@code null}
 * @return the value wrapped in single quotes with embedded single quotes doubled
 */
static String quoteString(String input) {
  // String.valueOf keeps the historical 'null' rendering instead of throwing an NPE.
  return "'" + String.valueOf(input).replace("'", "''") + "'";
}

private void populateInsertUpdateMap(Map<PartitionInfo, ColumnStatistics> statsPartInfoMap,
Map<PartColNameInfo, MPartitionColumnStatistics> updateMap,
Map<PartColNameInfo, MPartitionColumnStatistics>insertMap,
Expand Down Expand Up @@ -412,35 +409,39 @@ private Map<String, Map<String, String>> updatePartitionParamTable(Connection db

private Map<PartitionInfo, ColumnStatistics> getPartitionInfo(Connection dbConn, long tblId,
Map<String, ColumnStatistics> partColStatsMap)
throws SQLException, MetaException {
List<String> queries = new ArrayList<>();
StringBuilder prefix = new StringBuilder();
StringBuilder suffix = new StringBuilder();
throws MetaException {
Map<PartitionInfo, ColumnStatistics> partitionInfoMap = new HashMap<>();
List<String> partNames = new ArrayList<>(partColStatsMap.keySet());
if (partNames.isEmpty()) {
return partitionInfoMap;
}

List<String> partKeys = partColStatsMap.keySet().stream().map(
e -> quoteString(e)).collect(Collectors.toList()
);

prefix.append("select \"PART_ID\", \"WRITE_ID\", \"PART_NAME\" from \"PARTITIONS\" where ");
suffix.append(" and \"TBL_ID\" = " + tblId);
TxnUtils.buildQueryWithINClauseStrings(conf, queries, prefix, suffix,
partKeys, "\"PART_NAME\"", true, false);

try (Statement statement = dbConn.createStatement()) {
for (String query : queries) {
Batchable.runBatched(maxBatchSize, partNames, new Batchable<String, Void>() {
@Override
public List<Void> run(List<String> input) throws Exception {
String placeholders = makeParams(input.size());
String query = "select \"PART_ID\", \"WRITE_ID\", \"PART_NAME\" from \"PARTITIONS\" where "
+ "\"PART_NAME\" in (" + placeholders + ") and \"TBL_ID\" = ?";
// Select for update makes sure that the partitions are not modified while the stats are getting updated.
query = sqlGenerator.addForUpdateClause(query);
LOG.debug("Execute query: " + query);
try (ResultSet rs = statement.executeQuery(query)) {
while (rs.next()) {
PartitionInfo partitionInfo = new PartitionInfo(rs.getLong(1),
rs.getLong(2), rs.getString(3));
partitionInfoMap.put(partitionInfo, partColStatsMap.get(rs.getString(3)));
try (PreparedStatement ps = dbConn.prepareStatement(query)) {
int paramIndex = 1;
for (String partName : input) {
ps.setString(paramIndex++, partName);
}
ps.setLong(paramIndex, tblId);
try (ResultSet rs = ps.executeQuery()) {
while (rs.next()) {
String partName = rs.getString(3);
PartitionInfo partitionInfo = new PartitionInfo(rs.getLong(1), rs.getLong(2), partName);
partitionInfoMap.put(partitionInfo, partColStatsMap.get(partName));
}
}
}
return Collections.emptyList();
}
}
});
return partitionInfoMap;
}

Expand Down Expand Up @@ -473,6 +474,10 @@ public Map<String, Map<String, String>> updatePartitionColumnStatistics(Map<Stri

Map<PartitionInfo, ColumnStatistics> partitionInfoMap = getPartitionInfo(dbConn, tbl.getId(), partColStatsMap);

if (partitionInfoMap.isEmpty()) {
return Collections.emptyMap();
}

result = updatePartitionParamTable(dbConn, partitionInfoMap, validWriteIds,
writeId, TxnUtils.isAcidTable(tbl));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ public List<Partition> getPartitionsViaPartNames(final String catName, final Str
return Batchable.runBatched(batchSize, partNames, new Batchable<String, Partition>() {
@Override
public List<Partition> run(List<String> input) throws MetaException {
return getPartitionsByNames(catName, dbName, tblName, partNames, false, args);
return getPartitionsByNames(catName, dbName, tblName, input, false, args);
}
});
}
Expand Down Expand Up @@ -1028,9 +1028,7 @@ private List<Partition> getPartitionsByNames(String catName, String dbName,
throws MetaException {
// Get most of the fields for the partNames provided.
// Assume db and table names are the same for all partition, as provided in arguments.
String quotedPartNames = partNameList.stream()
.map(DirectSqlUpdatePart::quoteString)
.collect(Collectors.joining(","));
String partNameParams = makeParams(partNameList.size());

String queryText =
"select " + PARTITIONS + ".\"PART_ID\"," + SDS + ".\"SD_ID\"," + SDS + ".\"CD_ID\","
Expand All @@ -1043,11 +1041,18 @@ private List<Partition> getPartitionsByNames(String catName, String dbName,
+ " left outer join " + SERDES + " on " + SDS + ".\"SERDE_ID\" = " + SERDES + ".\"SERDE_ID\" "
+ " inner join " + TBLS + " on " + TBLS + ".\"TBL_ID\" = " + PARTITIONS + ".\"TBL_ID\" "
+ " inner join " + DBS + " on " + DBS + ".\"DB_ID\" = " + TBLS + ".\"DB_ID\" "
+ " where \"PART_NAME\" in (" + quotedPartNames + ") "
+ " where " + PARTITIONS + ".\"PART_NAME\" in (" + partNameParams + ") "
+ " and " + TBLS + ".\"TBL_NAME\" = ? and " + DBS + ".\"NAME\" = ? and " + DBS
+ ".\"CTLG_NAME\" = ? order by \"PART_NAME\" asc";

Object[] params = new Object[]{tblName, dbName, catName};
Object[] params = new Object[partNameList.size() + 3];
int i = 0;
for (String partName : partNameList) {
params[i++] = partName;
}
params[i++] = tblName;
params[i++] = dbName;
params[i] = catName;
return getPartitionsByQuery(catName, dbName, tblName, queryText, params, isAcidTable, args);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ public class TestObjectStore {
private static final String USER1 = "testobjectstoreuser1";
private static final String ROLE1 = "testobjectstorerole1";
private static final String ROLE2 = "testobjectstorerole2";
// Partition name carrying a SQL-injection payload: it closes a quoted literal and a
// parenthesis, appends an always-true OR condition, and ends with a SQL comment marker.
// The *RejectsSqlInjectionInPartName tests pass it where a partition name is expected.
private static final String SQLI_PART_NAME = "test_part_col=missing') OR 1=1 -- ";
// The three partition names the injection tests look up and expect to still find.
// Presumably these are the partitions created by createPartitionedTable -- confirm.
private static final List<String> ALL_PART_NAMES =
Arrays.asList("test_part_col=a0", "test_part_col=a1", "test_part_col=a2");
private static final Logger LOG = LoggerFactory.getLogger(TestObjectStore.class.getName());

private static final class LongSupplier implements Supplier<Long> {
Expand Down Expand Up @@ -802,6 +805,21 @@ public void testDirectSQLDropPartitionsCacheInSession()
Assert.assertEquals(1, partitions.size());
}

/**
 * Issues a partition drop whose only name is the SQL-injection payload, then checks
 * that all three partitions listed in ALL_PART_NAMES can still be fetched.
 */
@Test
public void testDirectSQLDropPartitionsRejectsSqlInjectionInPartName() throws Exception {
  createPartitionedTable(false, false, new HashSet<>());

  // The payload is handed over as if it were an ordinary partition name.
  List<String> maliciousNames = Collections.singletonList(SQLI_PART_NAME);
  objectStore.dropPartitionsInternal(DEFAULT_CATALOG_NAME, DB1, TABLE1, maliciousNames, true, false);

  List<Partition> survivors;
  try (AutoCloseable ignored = deadline()) {
    survivors = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, ALL_PART_NAMES);
  }

  // Nothing may have been dropped by the injected condition.
  int survivorCount = survivors.size();
  Assert.assertEquals(3, survivorCount);
}

/**
* Checks if the JDO cache is able to handle directSQL partition drops cross sessions.
*/
Expand Down Expand Up @@ -1024,6 +1042,63 @@ public void testDeletePartitionColumnStatisticsWhenEngineHasSpecialCharacter() t
List.of("test_part_col=a2"), null, "special '");
}

/**
 * Looks up partitions using the SQL-injection payload as the only name and expects
 * no match, then confirms the three names in ALL_PART_NAMES still resolve.
 */
@Test
public void testGetPartitionsByNamesRejectsSqlInjectionInPartName() throws Exception {
  createPartitionedTable(true, true, new HashSet<>());

  List<Partition> fetched;

  // The payload must be treated as a literal name that matches nothing.
  try (AutoCloseable ignored = deadline()) {
    List<String> injectedNames = Collections.singletonList(SQLI_PART_NAME);
    fetched = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, injectedNames);
  }
  Assert.assertEquals(0, fetched.size());

  // A regular lookup still returns every expected partition.
  try (AutoCloseable ignored = deadline()) {
    fetched = objectStore.getPartitionsByNames(DEFAULT_CATALOG_NAME, DB1, TABLE1, ALL_PART_NAMES);
  }
  Assert.assertEquals(3, fetched.size());
}

/**
 * Submits a batch column-statistics update keyed by the SQL-injection payload and
 * checks that the statistics read back for ALL_PART_NAMES are unchanged.
 *
 * <p>The check compares {@code numNulls} of every statistics entry read after the
 * update with the value captured from the first entry before the update; the
 * injected statistics use {@code numNulls = 999}.
 */
@Test
public void testUpdatePartitionColumnStatisticsInBatchRejectsSqlInjectionInPartName()
    throws Exception {
  createPartitionedTable(true, true, new HashSet<>());
  Table tbl = objectStore.getTable(DEFAULT_CATALOG_NAME, DB1, TABLE1);

  // Capture the statistics before the malicious update.
  List<List<ColumnStatistics>> baseline;
  try (AutoCloseable c = deadline()) {
    baseline = objectStore.getPartitionColumnStatistics(DEFAULT_CATALOG_NAME, DB1, TABLE1,
        ALL_PART_NAMES, Collections.singletonList("test_part_col"));
  }
  Assert.assertEquals(1, baseline.size());
  Assert.assertEquals(3, baseline.get(0).size());
  // NOTE(review): only the first entry is sampled; the loop at the end assumes all
  // three partitions start with the same numNulls -- confirm against createPartitionedTable.
  long baselineNumNulls = baseline.get(0).get(0).getStatsObj().get(0).getStatsData()
      .getLongStats().getNumNulls();

  // Build statistics whose partition name is the injection payload.
  ColumnStatisticsDesc statsDesc = new ColumnStatisticsDesc(false, DB1, TABLE1);
  statsDesc.setCatName(DEFAULT_CATALOG_NAME);
  statsDesc.setPartName(SQLI_PART_NAME);
  ColumnStatisticsData injectedData = new ColStatsBuilder<>(long.class).numNulls(999).numDVs(2)
      .low(3L).high(4L).build();
  ColumnStatisticsObj statsObj = new ColumnStatisticsObj("test_part_col", "int", injectedData);
  ColumnStatistics maliciousStats = new ColumnStatistics(statsDesc,
      Collections.singletonList(statsObj));
  maliciousStats.setEngine(ENGINE);

  Map<String, ColumnStatistics> statsMap = new HashMap<>();
  statsMap.put(SQLI_PART_NAME, maliciousStats);
  objectStore.updatePartitionColumnStatisticsInBatch(statsMap, tbl, null, null, -1);

  // Re-read the statistics and make sure no partition picked up the injected values.
  List<List<ColumnStatistics>> after;
  try (AutoCloseable c = deadline()) {
    after = objectStore.getPartitionColumnStatistics(DEFAULT_CATALOG_NAME, DB1, TABLE1,
        ALL_PART_NAMES, Collections.singletonList("test_part_col"));
  }
  // Mirror the baseline shape check so an unexpected result shape fails as an
  // assertion rather than as an IndexOutOfBoundsException from get(0).
  Assert.assertEquals(1, after.size());
  Assert.assertEquals(3, after.get(0).size());
  for (ColumnStatistics cs : after.get(0)) {
    Assert.assertEquals(baselineNumNulls,
        cs.getStatsObj().get(0).getStatsData().getLongStats().getNumNulls());
  }
}

private void setAggrConf(boolean enableBitVector, boolean enableKll, int batchSize) {
Configuration conf2 = MetastoreConf.newMetastoreConf(conf);
MetastoreConf.setBoolVar(conf2, ConfVars.STATS_FETCH_BITVECTOR, enableBitVector);
Expand Down
Loading