Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions compressai_vision/datasets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,13 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
), "A proper mapper information via cfg must be provided"
mapper = DatasetMapper(kwargs["cfg"], False)

self._org_mapper_func = PicklableWrapper(DatasetMapper(kwargs["cfg"], False))
self._cfg = kwargs["cfg"]
self._org_mapper_func = self._build_mapper_func(self._cfg)

if self.input_agumentation_bypass:
emptyAugList = AugmentationList([])
mapper.augmentations = emptyAugList

if hasattr(self._org_mapper_func, "_obj"):
self._org_mapper_func._obj.augmentations = emptyAugList
else:
self._org_mapper_func.augmentations = emptyAugList

self.mapDataset = MapDataset(_dataset, mapper)

metaData = MetadataCatalog.get(dataset_name)
Expand All @@ -281,7 +277,22 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
except AttributeError:
self.logger.warning("No attribute: thing_classes")

def get_org_mapper_func(self):
def _build_mapper_func(self, cfg):
    """Create a picklable mapper callable from ``cfg``.

    A ``DatasetMapper`` is constructed from ``cfg`` (second positional
    argument ``False`` -- presumably ``is_train``; confirm against the
    ``DatasetMapper`` in use). When ``self.input_agumentation_bypass`` is
    truthy, the mapper's augmentations are replaced by an empty
    ``AugmentationList``.

    Args:
        cfg: configuration object handed to ``DatasetMapper``.

    Returns:
        The mapper wrapped in ``PicklableWrapper``.
    """
    base_mapper = DatasetMapper(cfg, False)

    bypass_augmentation = self.input_agumentation_bypass
    if bypass_augmentation:
        # Swap in an empty augmentation list instead of the configured one.
        base_mapper.augmentations = AugmentationList([])

    wrapped_mapper = PicklableWrapper(base_mapper)
    return wrapped_mapper

def _get_rgb_cfg(self):
    """Return a copy of the stored config with ``INPUT.FORMAT`` set to "RGB".

    The copy is obtained through ``self._cfg.clone()``; it is defrosted for
    the single assignment and frozen again before being returned.
    """
    rgb_cfg = self._cfg.clone()

    # The clone is opened for writing only for this one field.
    rgb_cfg.defrost()
    rgb_cfg.INPUT.FORMAT = "RGB"
    rgb_cfg.freeze()

    return rgb_cfg

def get_org_mapper_func(self, use_rgb=False):
    """Return the mapper callable for this dataset.

    Args:
        use_rgb: when falsy, the mapper stored in ``self._org_mapper_func``
            is returned. When truthy, a mapper built from the RGB variant
            of the config (``self._get_rgb_cfg()``) is returned.

    Returns:
        The requested mapper. The RGB mapper is built on first request and
        then reused, so repeated calls return the same object, just like
        the non-RGB mapper.
    """
    if not use_rgb:
        return self._org_mapper_func

    # Build lazily and memoize: callers request the RGB mapper repeatedly,
    # and each build clones the config and constructs a new mapper.
    # getattr() is used so no attribute has to be pre-declared in __init__.
    rgb_mapper = getattr(self, "_rgb_mapper_func", None)
    if rgb_mapper is None:
        rgb_mapper = self._build_mapper_func(self._get_rgb_cfg())
        self._rgb_mapper_func = rgb_mapper
    return rgb_mapper

def __getitem__(self, idx):
Expand Down Expand Up @@ -327,7 +338,7 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
except AttributeError:
self.logger.warning("No attribute: thing_classes")

def get_org_mapper_func(self):
def get_org_mapper_func(self, use_rgb=False):
    """Return the mapper stored on this dataset.

    Args:
        use_rgb: accepted but not consulted by this implementation; the
            same stored mapper is returned for either value.

    Returns:
        ``self._org_mapper_func``.
    """
    stored_mapper = self._org_mapper_func
    return stored_mapper

def __getitem__(self, idx):
Expand Down Expand Up @@ -357,7 +368,7 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
self.mapDataset = MapDataset(_dataset, mapper)
self._org_mapper_func = PicklableWrapper(JDECustomMapper(kwargs["patch_size"]))

def get_org_mapper_func(self):
def get_org_mapper_func(self, use_rgb=False):
    """Return the mapper stored on this dataset.

    Args:
        use_rgb: accepted but not consulted by this implementation; the
            same stored mapper is returned for either value.

    Returns:
        ``self._org_mapper_func``.
    """
    stored_mapper = self._org_mapper_func
    return stored_mapper

def __getitem__(self, idx):
Expand Down Expand Up @@ -389,7 +400,6 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
self._org_mapper_func = PicklableWrapper(
YOLOXCustomMapper(kwargs["patch_size"])
)

metaData = MetadataCatalog.get(dataset_name)
try:
self.thing_classes = metaData.thing_classes
Expand All @@ -399,7 +409,9 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
except AttributeError:
self.logger.warning("No attribute: thing_classes")

def get_org_mapper_func(self):
def get_org_mapper_func(self, use_rgb=False):
    """Return the mapper callable for this dataset.

    Args:
        use_rgb: when falsy, the stored ``self._org_mapper_func`` is
            returned. When truthy, a new ``YOLOXCustomMapper`` created with
            ``use_rgb=True`` is wrapped in ``PicklableWrapper`` and
            returned; a fresh instance is built on every such call.

    Returns:
        The stored mapper, or a newly built RGB mapper.
    """
    if not use_rgb:
        return self._org_mapper_func

    # NOTE(review): the RGB mapper is sized from self.input_size and uses
    # the default aug_transforms, whereas the stored mapper is built from
    # kwargs["patch_size"] -- confirm the two always agree.
    rgb_mapper = YOLOXCustomMapper(self.input_size, use_rgb=True)
    return PicklableWrapper(rgb_mapper)

def __getitem__(self, idx):
Expand Down Expand Up @@ -441,7 +453,9 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
except AttributeError:
self.logger.warning("No attribute: thing_classes")

def get_org_mapper_func(self):
def get_org_mapper_func(self, use_rgb=False):
    """Return the mapper callable for this dataset.

    Args:
        use_rgb: when falsy, the stored ``self._org_mapper_func`` is
            returned. When truthy, a new ``MMPOSECustomMapper`` created
            with ``use_rgb=True`` is wrapped in ``PicklableWrapper`` and
            returned; a fresh instance is built on every such call.

    Returns:
        The stored mapper, or a newly built RGB mapper.
    """
    if not use_rgb:
        return self._org_mapper_func

    # NOTE(review): only self.input_size and use_rgb are forwarded; every
    # other constructor argument takes its default -- confirm this matches
    # how the stored mapper was configured.
    rgb_mapper = MMPOSECustomMapper(self.input_size, use_rgb=True)
    return PicklableWrapper(rgb_mapper)

def __getitem__(self, idx):
Expand Down
12 changes: 11 additions & 1 deletion compressai_vision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
size_factor=32,
pad_val=[114, 114, 114],
aug_transforms=None,
use_rgb=False,
):
"""
Args:
Expand All @@ -102,6 +103,8 @@ def __init__(
else:
self.aug_transforms = transforms.Compose([transforms.ToTensor()])

self.use_rgb = use_rgb

def compute_scale_and_center(self, src_img_width, src_img_height):
_input_h, _input_w = self.input_img_size
_ratio = src_img_width / src_img_height
Expand Down Expand Up @@ -137,6 +140,8 @@ def __call__(self, dataset_dict):
# tried to replicate the implementation of the original code
# Read image
org_img = cv2.imread(dataset_dict["file_name"]) # return img in BGR by default
if self.use_rgb:
org_img = cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB)

assert (
len(org_img.shape) == 3
Expand Down Expand Up @@ -188,7 +193,7 @@ class YOLOXCustomMapper:

"""

def __init__(self, img_size=[640, 640], aug_transforms=None):
def __init__(self, img_size=[640, 640], aug_transforms=None, use_rgb=False):
"""
Args:
img_size: expected input size (Height, Width)
Expand All @@ -201,6 +206,8 @@ def __init__(self, img_size=[640, 640], aug_transforms=None):
else:
self.aug_transforms = transforms.Compose([transforms.ToTensor()])

self.use_rgb = use_rgb

def __call__(self, dataset_dict):
"""
Args:
Expand All @@ -219,6 +226,9 @@ def __call__(self, dataset_dict):
# Read image
org_img = cv2.imread(dataset_dict["file_name"]) # return img in BGR by default

if self.use_rgb:
org_img = cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB)

assert (
len(org_img.shape) == 3
), f"detect an input image with 2 chs, {dataset_dict['file_name']}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ def __call__(
"org_input_size": org_img_size,
}

resize_mapper = org_map_func if self._compress_after_resizing else None
resize_mapper = (
dataloader.dataset.get_org_mapper_func(use_rgb=True)
if self._compress_after_resizing
else None
)

res, enc_time_details, _ = self._compress(
codec,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def __call__(

start = time_measure()

resize_mapper = org_map_func if self._compress_after_resizing else None
resize_mapper = (
dataloader.dataset.get_org_mapper_func(use_rgb=True)
if self._compress_after_resizing
else None
)

res, enc_time_by_module, enc_complexity = self._compress(
codec,
Expand Down
Loading