From a4faee74d9fa5b5ce25f92469d36c4e97afa825f Mon Sep 17 00:00:00 2001 From: Jianwen Song Date: Mon, 18 May 2026 10:19:54 +1000 Subject: [PATCH] fix: use rgb resized image before remote for resizing mode --- compressai_vision/datasets/image.py | 38 +++++++++++++------ compressai_vision/datasets/utils.py | 12 +++++- .../image_remote_inference.py | 6 ++- .../video_remote_inference.py | 6 ++- 4 files changed, 47 insertions(+), 15 deletions(-) diff --git a/compressai_vision/datasets/image.py b/compressai_vision/datasets/image.py index 90d64df..6fbabcb 100644 --- a/compressai_vision/datasets/image.py +++ b/compressai_vision/datasets/image.py @@ -259,17 +259,13 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): ), "A proper mapper information via cfg must be provided" mapper = DatasetMapper(kwargs["cfg"], False) - self._org_mapper_func = PicklableWrapper(DatasetMapper(kwargs["cfg"], False)) + self._cfg = kwargs["cfg"] + self._org_mapper_func = self._build_mapper_func(self._cfg) if self.input_agumentation_bypass: emptyAugList = AugmentationList([]) mapper.augmentations = emptyAugList - if hasattr(self._org_mapper_func, "_obj"): - self._org_mapper_func._obj.augmentations = emptyAugList - else: - self._org_mapper_func.augmentations = emptyAugList - self.mapDataset = MapDataset(_dataset, mapper) metaData = MetadataCatalog.get(dataset_name) @@ -281,7 +277,22 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): except AttributeError: self.logger.warning("No attribute: thing_classes") - def get_org_mapper_func(self): + def _build_mapper_func(self, cfg): + mapper = DatasetMapper(cfg, False) + if self.input_agumentation_bypass: + mapper.augmentations = AugmentationList([]) + return PicklableWrapper(mapper) + + def _get_rgb_cfg(self): + cfg = self._cfg.clone() + cfg.defrost() + cfg.INPUT.FORMAT = "RGB" + cfg.freeze() + return cfg + + def get_org_mapper_func(self, use_rgb=False): + if use_rgb: + return self._build_mapper_func(self._get_rgb_cfg()) return self._org_mapper_func def __getitem__(self, idx): @@ -327,7 +338,7 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): except AttributeError: self.logger.warning("No attribute: thing_classes") - def get_org_mapper_func(self): + def get_org_mapper_func(self, use_rgb=False): return self._org_mapper_func def __getitem__(self, idx): @@ -357,7 +368,7 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): self.mapDataset = MapDataset(_dataset, mapper) self._org_mapper_func = PicklableWrapper(JDECustomMapper(kwargs["patch_size"])) - def get_org_mapper_func(self): + def get_org_mapper_func(self, use_rgb=False): return self._org_mapper_func def __getitem__(self, idx): @@ -389,7 +400,6 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): self._org_mapper_func = PicklableWrapper( YOLOXCustomMapper(kwargs["patch_size"]) ) - metaData = MetadataCatalog.get(dataset_name) try: self.thing_classes = metaData.thing_classes @@ -399,7 +409,9 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): except AttributeError: self.logger.warning("No attribute: thing_classes") - def get_org_mapper_func(self): + def get_org_mapper_func(self, use_rgb=False): + if use_rgb: + return PicklableWrapper(YOLOXCustomMapper(self.input_size, use_rgb=True)) return self._org_mapper_func def __getitem__(self, idx): @@ -441,7 +453,9 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs): except AttributeError: self.logger.warning("No attribute: thing_classes") - def get_org_mapper_func(self): + def get_org_mapper_func(self, use_rgb=False): + if use_rgb: + return PicklableWrapper(MMPOSECustomMapper(self.input_size, use_rgb=True)) return self._org_mapper_func def __getitem__(self, idx): diff --git a/compressai_vision/datasets/utils.py b/compressai_vision/datasets/utils.py index 81a37ce..e4c0a45 100644 --- a/compressai_vision/datasets/utils.py +++ b/compressai_vision/datasets/utils.py @@ -87,6 +87,7 @@ def __init__( size_factor=32, pad_val=[114, 114, 114], aug_transforms=None, + use_rgb=False, ): """ Args: @@ -102,6 +103,8 @@ def __init__( else: self.aug_transforms = transforms.Compose([transforms.ToTensor()]) + self.use_rgb = use_rgb + def compute_scale_and_center(self, src_img_width, src_img_height): _input_h, _input_w = self.input_img_size _ratio = src_img_width / src_img_height @@ -137,6 +140,8 @@ def __call__(self, dataset_dict): # tried to replicate the implemetation of the original codes # Read image org_img = cv2.imread(dataset_dict["file_name"]) # return img in BGR by default + if self.use_rgb: + org_img = cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB) assert ( len(org_img.shape) == 3 @@ -188,7 +193,7 @@ class YOLOXCustomMapper: """ - def __init__(self, img_size=[640, 640], aug_transforms=None): + def __init__(self, img_size=[640, 640], aug_transforms=None, use_rgb=False): """ Args: img_size: expected input size (Height, Width) @@ -201,6 +206,8 @@ def __init__(self, img_size=[640, 640], aug_transforms=None): else: self.aug_transforms = transforms.Compose([transforms.ToTensor()]) + self.use_rgb = use_rgb + def __call__(self, dataset_dict): """ Args: @@ -219,6 +226,9 @@ def __call__(self, dataset_dict): # Read image org_img = cv2.imread(dataset_dict["file_name"]) # return img in BGR by default + if self.use_rgb: + org_img = cv2.cvtColor(org_img, cv2.COLOR_BGR2RGB) + assert ( len(org_img.shape) == 3 ), f"detect an input image with 2 chs, {dataset_dict['file_name']}" diff --git a/compressai_vision/pipelines/remote_inference/image_remote_inference.py b/compressai_vision/pipelines/remote_inference/image_remote_inference.py index 47d4ec6..bc8085f 100644 --- a/compressai_vision/pipelines/remote_inference/image_remote_inference.py +++ b/compressai_vision/pipelines/remote_inference/image_remote_inference.py @@ -115,7 +115,11 @@ def __call__( "org_input_size": org_img_size, } - resize_mapper = org_map_func if self._compress_after_resizing else None + resize_mapper = ( + dataloader.dataset.get_org_mapper_func(use_rgb=True) + if self._compress_after_resizing + else None + ) res, enc_time_details, _ = self._compress( codec, diff --git a/compressai_vision/pipelines/remote_inference/video_remote_inference.py b/compressai_vision/pipelines/remote_inference/video_remote_inference.py index 3df9366..e5180ca 100644 --- a/compressai_vision/pipelines/remote_inference/video_remote_inference.py +++ b/compressai_vision/pipelines/remote_inference/video_remote_inference.py @@ -128,7 +128,11 @@ def __call__( start = time_measure() - resize_mapper = org_map_func if self._compress_after_resizing else None + resize_mapper = ( + dataloader.dataset.get_org_mapper_func(use_rgb=True) + if self._compress_after_resizing + else None + ) res, enc_time_by_module, enc_complexity = self._compress( codec,