Comments (2)

moon6666 commented on June 25, 2024

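Alternatively, you can add a wrapper directly to mmdetection's mmdet/datasets/dataset_wrappers.py. Here is mine for reference. This makes it easy to use all of the data augmentations implemented in mmdetection.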
```python
import collections
import copy
import random

from mmcv.utils import build_from_cfg

from .builder import DATASETS, PIPELINES


@DATASETS.register_module()
class SeqMultiImageMixDataset:
    """Apply mmdet transforms, including multi-image ones such as Mosaic and
    MixUp, to every frame of a sequence sample (a list of frame dicts)."""

    # Transforms that operate on the whole list of frames at once.
    SEQ_TRANSFORMS = ('VideoCollect', 'ConcatVideoReferences',
                      'SeqDefaultFormatBundle')
    # Per-frame augmentations that are randomly skipped (p > 0.8, ~20%).
    RANDOM_SKIP_TRANSFORMS = ('Mosaic', 'RandomAffine', 'MixUp', 'Expand',
                              'PhotoMetricDistortion', 'YOLOXHSVRandomAug')

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None,
                 max_refetch=15):
        if dynamic_scale is not None:
            raise RuntimeError(
                'dynamic_scale is deprecated. Please use Resize pipeline '
                'to achieve similar functions')
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)
        # Kept for compatibility with MultiImageMixDataset; unused here.
        self.max_refetch = max_refetch

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # A sequence sample: a list with one results dict per frame.
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            # 1. For multi-image transforms (Mosaic, MixUp, ...), fetch the
            #    extra samples to mix in.
            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                # Frame j of this sequence is mixed with frame j of each
                # fetched sequence, preserving temporal order.
                for j in range(len(results)):
                    mix_results = [
                        copy.deepcopy(self.dataset[index][j])
                        for index in indexes
                    ]
                    if None not in mix_results:
                        results[j]['mix_results'] = mix_results

            # 2. Apply the transform to the prepared frames.
            if transform_type in self.SEQ_TRANSFORMS:
                updated_results = transform(copy.deepcopy(results))
                if updated_results is not None:
                    results = updated_results
            else:
                p = random.uniform(0, 1)
                if transform_type in self.RANDOM_SKIP_TRANSFORMS and p > 0.8:
                    continue
                # One seed shared by all frames of the sequence. Stock mmdet
                # transforms ignore 'seed'; it only keeps the augmentation
                # consistent if the transforms are modified to read it.
                seed = random.randint(1, 100)
                for j in range(len(results)):
                    results[j]['seed'] = seed
                    updated_results = transform(copy.deepcopy(results[j]))
                    if updated_results is not None:
                        results[j] = updated_results

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                strings for the pipeline transforms to skip.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys
```
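To make the control flow of `__getitem__` easier to follow without mmdet installed, here is a minimal, framework-free sketch. All names in it (`ToySeqDataset`, `ToyMix`, `ToySeededFlip`, `apply_pipeline`) are hypothetical stand-ins, not mmdet or mmtracking APIs. It assumes every sequence has the same number of frames and that transforms derive their randomness from `results['seed']`, which is what lets the shared seed keep the augmentation consistent across frames. Random skipping and the sequence-level transforms are left out.

```python
import copy
import random


class ToySeqDataset:
    """Hypothetical dataset: each item is a sequence (list) of frame dicts."""

    def __init__(self, num_items, frames_per_item):
        self.items = [[{'id': (i, j)} for j in range(frames_per_item)]
                      for i in range(num_items)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class ToyMix:
    """Hypothetical Mosaic-like transform: records which frames it mixed in."""

    def get_indexes(self, dataset):
        # Pick 3 sequences at random (randint bounds are inclusive).
        return [random.randint(0, len(dataset) - 1) for _ in range(3)]

    def __call__(self, results):
        results['mixed_with'] = [m['id'] for m in results.pop('mix_results')]
        return results


class ToySeededFlip:
    """Hypothetical transform whose randomness comes only from results['seed']."""

    def __call__(self, results):
        rng = random.Random(results['seed'])
        results['flipped'] = rng.random() < 0.5
        return results


def apply_pipeline(dataset, idx, pipeline):
    """Mirror of the wrapper's per-frame logic (no skipping, no seq transforms)."""
    results = copy.deepcopy(dataset[idx])  # never mutate the dataset
    for transform in pipeline:
        if hasattr(transform, 'get_indexes'):
            indexes = transform.get_indexes(dataset)
            if not isinstance(indexes, (list, tuple)):
                indexes = [indexes]
            # Frame j is mixed with frame j of each fetched sequence.
            for j in range(len(results)):
                results[j]['mix_results'] = [
                    copy.deepcopy(dataset[i][j]) for i in indexes
                ]
        seed = random.randint(1, 100)  # one seed shared by every frame
        for j in range(len(results)):
            results[j]['seed'] = seed
            results[j] = transform(copy.deepcopy(results[j]))
    return results


random.seed(0)
ds = ToySeqDataset(num_items=5, frames_per_item=2)
out = apply_pipeline(ds, 0, [ToyMix(), ToySeededFlip()])
for frame in out:
    print(frame['id'], frame['mixed_with'], frame['flipped'])
```

Both frames end up with the same `flipped` value because they share one seed, and each frame is mixed only with the same-index frame of the fetched sequences, so temporal order is preserved.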

from mmtracking.

moon6666 commented on June 25, 2024

Sorry, that function was one I added myself; the mistake was mine.
