diff --git a/04_pep8/submissions/jeon193.py b/04_pep8/submissions/jeon193.py new file mode 100644 index 0000000..2b1eab4 --- /dev/null +++ b/04_pep8/submissions/jeon193.py @@ -0,0 +1,314 @@ +# Modifiedy from detectron2i +# https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/test_time_augmentation.py +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +from contextlib import contextmanager +from itertools import count +from typing import List + +import numpy as np +import torch +from detectron2.config import configurable +from detectron2.data.detection_utils import read_image +from detectron2.data.transforms import ( + RandomFlip, + ResizeShortestEdge, + ResizeTransform, + apply_augmentations, +) +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.modeling.postprocessing import detector_postprocess +from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image +from detectron2.structures import Boxes, Instances +from fvcore.transforms import HFlipTransform, NoOpTransform +from torch import nn +from torch.nn.parallel import DistributedDataParallel + +__all__ = ["DatasetMapperTTA", "GeneralizedRCNNWithTTA"] + + +# fmt: off +DUMMY_LIST1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] # noqa: E501 +DUMMY_LIST2 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] # noqa: E501 +DUMMY_LIST3 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] # noqa: E501 +print(DUMMY_LIST1, DUMMY_LIST2, DUMMY_LIST3) +# fmt: on + + +class DatasetMapperTTA: + """ + Implement test-time augmentation for detection data. + It is a callable which takes a dataset dict from a detection dataset, + and returns a list of dataset dicts where the images + are augmented from the input image by the transformations defined in the config. + This is used for test-time augmentation. + """ + + @configurable + def __init__(self, min_sizes: List[int], max_size: int, flip: bool): + """ + Args: + min_sizes: list of short-edge size to resize the image to + max_size: maximum height or width of resized images + flip: whether to apply flipping augmentation + """ + self.min_sizes = min_sizes + self.max_size = max_size + self.flip = flip + + @classmethod + def from_config(cls, cfg): + return { + "min_sizes": cfg.TEST.AUG.MIN_SIZES, + "max_size": cfg.TEST.AUG.MAX_SIZE, + "flip": cfg.TEST.AUG.FLIP, + } + + def __call__(self, dataset_dict): + """ + Args: + dict: a dict in standard model input format. See tutorials for details. + + Returns: + list[dict]: + a list of dicts, which contain augmented version of the input image. + The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``. + Each dict has field "transforms" which is a TransformList, + containing the transforms that are used to generate this image. + """ + numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy() + shape = numpy_image.shape + orig_shape = (dataset_dict["height"], dataset_dict["width"]) + if shape[:2] != orig_shape: + # It transforms the "original" image in the dataset to the input image + pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1]) + else: + pre_tfm = NoOpTransform() + + # Create all combinations of augmentations to use + aug_candidates = [] # each element is a list[Augmentation] + for min_size in self.min_sizes: + resize = ResizeShortestEdge(min_size, self.max_size) + aug_candidates.append([resize]) # resize only + if self.flip: + flip = RandomFlip(prob=1.0) + aug_candidates.append([resize, flip]) # resize + flip + + # Apply all the augmentations + ret = [] + for aug in aug_candidates: + new_image, tfms = apply_augmentations(aug, np.copy(numpy_image)) + torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1))) + + dic = copy.deepcopy(dataset_dict) + dic["transforms"] = pre_tfm + tfms + dic["image"] = torch_image + ret.append(dic) + return ret + + +class GeneralizedRCNNWithTTA(nn.Module): + """ + A GeneralizedRCNN with test-time augmentation enabled. + Its :meth:`__call__` method has the same interface as :meth:`GeneralizedRCNN.forward`. + """ + + def __init__(self, cfg, model, tta_mapper=None, batch_size=3): + """ + Args: + cfg (CfgNode): + model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on. + tta_mapper (callable): takes a dataset dict and returns a list of + augmented versions of the dataset dict. Defaults to + `DatasetMapperTTA(cfg)`. + batch_size (int): batch the augmented images into this batch size for inference. + """ + super().__init__() + if isinstance(model, DistributedDataParallel): + model = model.module + assert isinstance( + model, GeneralizedRCNN + ), "TTA is only supported on GeneralizedRCNN. Got a model of type {}".format(type(model)) + self.cfg = cfg.clone() + assert not self.cfg.MODEL.KEYPOINT_ON, "TTA for keypoint is not supported yet" + assert not self.cfg.MODEL.LOAD_PROPOSALS, "TTA for pre-computed proposals is not supported yet" # noqa: E501 + + self.model = model + + if tta_mapper is None: + tta_mapper = DatasetMapperTTA(cfg) + self.tta_mapper = tta_mapper + self.batch_size = batch_size + + @contextmanager + def _turn_off_roi_heads(self, attrs): + """ + Open a context where some heads in `model.roi_heads` are temporarily turned off. + Args: + attr (list[str]): the attribute in `model.roi_heads` which can be used + to turn off a specific head, e.g., "mask_on", "keypoint_on". + """ + roi_heads = self.model.roi_heads + old = {} + for attr in attrs: + try: + old[attr] = getattr(roi_heads, attr) + except AttributeError: + # The head may not be implemented in certain ROIHeads + pass + + if len(old.keys()) == 0: + yield + else: + for attr in old.keys(): + setattr(roi_heads, attr, False) + yield + for attr in old.keys(): + setattr(roi_heads, attr, old[attr]) + + def _batch_inference(self, batched_inputs, detected_instances=None): + """ + Execute inference on a list of inputs, + using batch size = self.batch_size, instead of the length of the list. + + Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference` + """ + if detected_instances is None: + detected_instances = [None] * len(batched_inputs) + + outputs = [] + inputs, instances = [], [] + for idx, input, instance in zip(count(), batched_inputs, detected_instances): + inputs.append(input) + instances.append(instance) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + outputs.extend( + self.model.inference( + inputs, + instances if instances[0] is not None else None, + do_postprocess=False, + ) + ) + inputs, instances = [], [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`GeneralizedRCNN.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + + Returns: + dict: one output dict + """ + orig_shape = (input["height"], input["width"]) + augmented_inputs, tfms = self._get_augmented_inputs(input) + # Detect boxes from all augmented versions + with self._turn_off_roi_heads(["mask_on", "keypoint_on"]): + # temporarily disable roi heads + all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms) + # merge all detected boxes to obtain final predictions for boxes + merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape) + + if self.cfg.MODEL.MASK_ON: + # Use the detected boxes to obtain masks + augmented_instances = self._rescale_detected_boxes( + augmented_inputs, merged_instances, tfms + ) + # run forward on the detected boxes + outputs = self._batch_inference(augmented_inputs, augmented_instances) + # Delete now useless variables to avoid being out of memory + del augmented_inputs, augmented_instances + # average the predictions + merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms) + merged_instances = detector_postprocess(merged_instances, *orig_shape) + return {"instances": merged_instances} + else: + return {"instances": merged_instances} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms + + def _get_augmented_boxes(self, augmented_inputs, tfms): + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # 2: union the results + all_boxes = [] + all_scores = [] + all_classes = [] + for output, tfm in zip(outputs, tfms): + # Need to inverse the transforms on boxes, to obtain results on original image + pred_boxes = output.pred_boxes.tensor + original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy()) + all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device)) + + all_scores.extend(output.scores) + all_classes.extend(output.pred_classes) + all_boxes = torch.cat(all_boxes, dim=0) + return all_boxes, all_scores, all_classes + + def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw): + # select from the union of all results + num_boxes = len(all_boxes) + num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES + # +1 because fast_rcnn_inference expects background scores as well + all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device) + for idx, cls, score in zip(count(), all_classes, all_scores): + all_scores_2d[idx, cls] = score + + merged_instances, _ = fast_rcnn_inference_single_image( + all_boxes, + all_scores_2d, + shape_hw, + 1e-8, + self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, + self.cfg.TEST.DETECTIONS_PER_IMAGE, + ) + + return merged_instances + + def _rescale_detected_boxes(self, augmented_inputs, merged_instances, tfms): + augmented_instances = [] + for input, tfm in zip(augmented_inputs, tfms): + # Transform the target box to the augmented image's coordinate space + pred_boxes = merged_instances.pred_boxes.tensor.cpu().numpy() + pred_boxes = torch.from_numpy(tfm.apply_box(pred_boxes)) + + aug_instances = Instances( + image_size=input["image"].shape[1:3], + pred_boxes=Boxes(pred_boxes), + pred_classes=merged_instances.pred_classes, + scores=merged_instances.scores, + ) + augmented_instances.append(aug_instances) + return augmented_instances + + def _reduce_pred_masks(self, outputs, tfms): + # Should apply inverse transforms on masks. + # We assume only resize & flip are used. pred_masks is a scale-invariant + # representation, so we handle flip specially + for output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + output.pred_masks = output.pred_masks.flip(dims=[3]) + all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0) + avg_pred_masks = torch.mean(all_pred_masks, dim=0) + return avg_pred_masks