Error when training for instance segmentation with a custom dataset #17
Hi @Robotatron, please find the answers to your questions below:
You need to pay attention to two files here: …
You encounter the … (see OneFormer/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py, line 306 in 33ebb56).
So, because you are using a custom dataset for instance segmentation, you should ensure the following two steps: …
It's pretty easy to ensure these, and it should take you only a short time. Remember: because you only care about instance segmentation, you need to define the thing_classes correctly in your …
@praeclarumjj3 Thanks for the answer. Dataset mapping is new to me, so please be patient with me :)
1. I am registering the dataset the same way as I did with Mask2Former.
2. Are you talking about these attributes?
If so, what do I put under "sem_seg", since I don't have semantic segmentation labels? Just an empty list?
I think I am close. I've tried using the dataset mapper from Mask2Former and modified it like this (mainly at the end, when setting the attributes of the …):
# Copyright (c) Facebook, Inc. and its affiliates.
import copy

import numpy as np
from detectron2.config import configurable

__all__ = ["MaskFormerInstanceDatasetMapper"]


class MaskFormerInstanceDatasetMapper:
    ...
It almost works: the model gets initialized and the dataloader provides images. But when training starts, it breaks. Full error:
/opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
warnings.warn(warning.format(ret))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[20], line 1
----> 1 trainer.train()
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/defaults.py:484, in DefaultTrainer.train(self)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/train_loop.py:149, in TrainerBase.train(self, start_iter, max_iter)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/defaults.py:494, in DefaultTrainer.run_step(self)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/train_loop.py:395, in AMPTrainer.run_step(self)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
File ~/OneFormer/oneformer/oneformer_model.py:296, in OneFormer.forward(self, batched_inputs)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
File ~/OneFormer/oneformer/modeling/criterion.py:294, in SetCriterion.forward(self, outputs, targets)
File ~/OneFormer/oneformer/modeling/criterion.py:268, in SetCriterion.get_loss(self, loss, outputs, targets, indices, num_masks)
File ~/OneFormer/oneformer/modeling/criterion.py:152, in SetCriterion.loss_contrastive(self, outputs, targets, indices, num_masks)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:822, in get_rank(group)
File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:410, in _get_default_group()
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
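The last frames show the failure: `torch.distributed.get_rank()` needs a default process group, and a single-GPU run typically never creates one (detectron2's launcher only sets up distributed training for more than one process). The sketch below shows the kind of guard that avoids the call in that case. It mirrors the `is_dist_avail_and_initialized` check used in the fixed loss code later in this thread, but this standalone version and the helper `safe_get_rank` are illustrative, not OneFormer's code; only `torch.distributed.is_available`, `is_initialized` and `get_rank` are real PyTorch functions.

```python
def is_dist_avail_and_initialized() -> bool:
    """True only if torch.distributed is usable AND a default process group exists."""
    try:
        import torch.distributed as dist
    except ImportError:  # PyTorch not installed at all
        return False
    # is_initialized() is only meaningful when distributed support is compiled in,
    # so check is_available() first (the `and` short-circuits).
    return dist.is_available() and dist.is_initialized()


def safe_get_rank() -> int:
    """Hypothetical helper: rank of this process, or 0 in a single-process run."""
    if not is_dist_avail_and_initialized():
        return 0
    import torch.distributed as dist
    return dist.get_rank()


print(safe_get_rank())  # 0 when no process group has been initialized
```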
Hi @Robotatron, you can refer to the following code to understand how to set the …
OneFormer/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py, lines 197 to 225 in 1780582
All you need to do is loop through the class IDs and collect the corresponding class names. You do not need to worry about the …
You encounter the …
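A minimal plain-Python sketch of that loop, assuming `class_names` is the dataset's `thing_classes` list and `gt_classes` holds the contiguous class IDs of one image; `build_texts` and `num_queries` are illustrative names. It follows the same scheme as `_get_texts` in the custom mapper posted later in this thread: one "a photo with a <name>" entry per ground-truth instance, grouped in `class_names` order, with the remaining queries keeping a generic prompt.

```python
from collections import Counter
from typing import List, Sequence


def build_texts(gt_classes: Sequence[int], class_names: Sequence[str],
                num_queries: int) -> List[str]:
    """Hypothetical helper: turn one image's class IDs into per-query text prompts."""
    texts = ["an instance photo"] * num_queries
    # Look up the class name for every instance and count occurrences per class.
    counts = Counter(class_names[int(c)] for c in gt_classes)
    slot = 0
    for name in class_names:  # group entries by class, in metadata order
        for _ in range(counts[name]):
            if slot >= num_queries:
                return texts  # more instances than queries: truncate
            texts[slot] = f"a photo with a {name}"
            slot += 1
    return texts


print(build_texts([2, 0, 2], ["cat", "dog", "car"], num_queries=5))
# ['a photo with a cat', 'a photo with a car', 'a photo with a car', 'an instance photo', 'an instance photo']
```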
Thanks for the reply @praeclarumjj3! I've looked at your linked issue (facebookresearch/detectron2#3972); it recommends replacing the "SyncBN" batch norm with a plain "BN". However, OneFormer does not seem to use "SyncBN"; it uses group norm ("GN"). Have you ever tried training OneFormer on a single machine with a single GPU, or do you know whether it is supported at all?
Hi @Robotatron, after taking a closer look at your error, I believe the error corresponds to the … To train on a single GPU, please replace the contrastive loss method with the code below, which checks whether the distributed process group has been initialized. I have also updated the …

# Method of SetCriterion in oneformer/modeling/criterion.py. torch, F (torch.nn.functional),
# dist (torch.distributed), is_dist_avail_and_initialized and dist_collect are assumed to be
# imported or defined at module level in that file.
def loss_contrastive(self, outputs, targets, indices, num_masks):
    assert "contrastive_logits" in outputs
    assert "texts" in outputs
    image_x = outputs["contrastive_logits"].float()

    batch_size = image_x.shape[0]
    # get label globally
    if is_dist_avail_and_initialized():
        # Offset targets by rank so they index into the gathered global batch.
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
    else:
        # Single process: no process group, so targets are simply 0..B-1.
        labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device)

    text_x = outputs["texts"]

    # [B, C]
    image_x = F.normalize(image_x.flatten(1), dim=-1)
    text_x = F.normalize(text_x.flatten(1), dim=-1)

    if is_dist_avail_and_initialized():
        # Compare local features with features collected from all processes.
        logits_per_img = image_x @ dist_collect(text_x).t()
        logits_per_text = text_x @ dist_collect(image_x).t()
    else:
        logits_per_img = image_x @ text_x.t()
        logits_per_text = text_x @ image_x.t()

    logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
    loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
    loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)

    loss_contrastive = loss_img + loss_text

    losses = {"loss_contrastive": loss_contrastive}
    return losses
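To make the single-process branch concrete, here is a NumPy re-implementation of the same symmetric (CLIP-style) image/text contrastive loss. It is a sketch, not OneFormer's code: it assumes `self.cross_entropy` is a mean-reduced cross-entropy (as `torch.nn.CrossEntropyLoss()` is by default), the default `logit_scale` of log(1/0.07) is an assumed CLIP-style value, and `contrastive_loss_single_process` is a hypothetical name.

```python
import numpy as np


def _log_softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable row-wise log-softmax.
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Row-wise L2 normalisation, like F.normalize(x, dim=-1).
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def contrastive_loss_single_process(image_x: np.ndarray, text_x: np.ndarray,
                                    logit_scale: float = float(np.log(1 / 0.07))) -> float:
    """Symmetric image/text contrastive loss with no all-gather.

    Row i of image_x and row i of text_x form the positive pair; every other
    row in the batch acts as a negative.
    """
    batch_size = image_x.shape[0]
    img = _l2_normalize(image_x.reshape(batch_size, -1))
    txt = _l2_normalize(text_x.reshape(batch_size, -1))

    # Without a process group the rank offset is 0, so targets are 0..B-1.
    labels = np.arange(batch_size)
    scale = min(float(np.exp(logit_scale)), 100.0)  # mirrors torch.clamp(..., max=100)

    logits_per_img = scale * (img @ txt.T)
    logits_per_text = scale * (txt @ img.T)

    # Mean cross-entropy on the diagonal (matching) entries, in both directions.
    loss_img = -_log_softmax(logits_per_img)[labels, labels].mean()
    loss_text = -_log_softmax(logits_per_text)[labels, labels].mean()
    return float(loss_img + loss_text)


feats = np.eye(4)  # four mutually orthogonal embeddings
print(contrastive_loss_single_process(feats, feats))                    # matched pairs: close to 0
print(contrastive_loss_single_process(feats, np.roll(feats, 1, axis=0)))  # mismatched pairs: about 28.57
```

The two directions (image-to-text and text-to-image) are summed, as in the method above, so the loss is symmetric in its two inputs.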
@praeclarumjj3 Thanks for updating the repo; training now works on a single GPU <3. One last question regarding the "text" attribute. You said earlier: …
1. Is dataset_dict["text"] used for training? I am not really sure what its purpose is; it seems to be a list of string messages of the format …
I took your code for setting up the "texts" attribute and tried to adjust it for the instance segmentation Mask2Former dataset mapper. I am stuck at the following two lines in your code (196 and 197): https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py#L196
A) There is no …
B) But then I was not sure what mask to use. I don't have …
My code for setting up the "text" attribute looks like this (my version on the right, your original code on the left): [image]
2. Can the text attribute be populated like in the image above, or should I still use the "mask" variable with the if statement?
@Robotatron, are you able to train OneFormer for instance segmentation?
Yes, it works and trains. Maybe I should submit a PR with a modified Mask2Former dataset mapper for instance segmentation, if @praeclarumjj3 is OK with it.
Hi @Robotatron, thanks for the offer. As many people have been opening issues regarding custom training, I will push some custom …
Hi @Robotatron, yes. We use … Also, you may use the following script to train OneFormer on a custom instance segmentation dataset. Remember that you will need to register your dataset's metadata with the class names.

import copy
import logging

import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
from pycocotools import mask as coco_mask

__all__ = ["InstanceCOCOCustomNewBaselineDatasetMapper"]


def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon annotations into an [N, H, W] mask tensor."""
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # An instance may consist of several polygons; merge them into one mask.
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


def build_transform_gen(cfg, is_train):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes flipping, resizing and cropping.
    Returns:
        list[Augmentation]
    """
    assert is_train, "Only support training augmentation"
    image_size = cfg.INPUT.IMAGE_SIZE
    min_scale = cfg.INPUT.MIN_SCALE
    max_scale = cfg.INPUT.MAX_SCALE

    augmentation = []

    if cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )

    augmentation.extend([
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ])

    return augmentation


# This is specifically designed for the COCO Instance Segmentation dataset.
class InstanceCOCOCustomNewBaselineDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by OneFormer for custom instance segmentation using COCO format.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies suitable cropping to the image and annotation
    4. Prepares image and annotation as Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        num_queries,
        tfm_gens,
        meta,
        image_format,
        max_seq_len,
        task_seq_len,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            num_queries: number of object queries that receive a text entry
            tfm_gens: data augmentation
            meta: dataset Metadata (needs thing_classes and thing_dataset_id_to_contiguous_id)
            image_format: an image format supported by :func:`detection_utils.read_image`.
            max_seq_len: maximum token length of the class texts
            task_seq_len: maximum token length of the task text
        """
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[InstanceCOCOCustomNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
                str(self.tfm_gens)
            )
        )

        self.img_format = image_format
        self.is_train = is_train
        self.meta = meta
        # Custom metadata may not define ignore_label; fall back to 255 instead of raising.
        self.ignore_label = getattr(self.meta, "ignore_label", 255)
        self.num_queries = num_queries

        # Contiguous IDs of all "thing" classes.
        self.things = list(self.meta.thing_dataset_id_to_contiguous_id.values())
        self.class_names = self.meta.thing_classes
        self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
        self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)
        dataset_names = cfg.DATASETS.TRAIN
        meta = MetadataCatalog.get(dataset_names[0])

        ret = {
            "is_train": is_train,
            "meta": meta,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
            "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
            "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
            "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
        }
        return ret

    def _get_texts(self, classes, num_class_obj):
        # One "a photo with a <class>" entry per ground-truth instance, grouped in
        # class order; the remaining queries keep the generic prompt.
        classes = list(np.array(classes))
        texts = ["an instance photo"] * self.num_queries

        for class_id in classes:
            cls_name = self.class_names[class_id]
            num_class_obj[cls_name] += 1

        num = 0
        for cls_name in self.class_names:
            if num_class_obj[cls_name] > 0:
                for _ in range(num_class_obj[cls_name]):
                    if num >= len(texts):
                        break
                    texts[num] = f"a photo with a {cls_name}"
                    num += 1

        return texts

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~padding_mask.astype(bool)  # True where the image was padded

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(annos, image_shape)
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Need to filter empty instances first (due to augmentation)
            instances = utils.filter_empty_instances(instances)
            # Generate masks from polygon
            h, w = instances.image_size
            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
            if hasattr(instances, 'gt_masks'):
                gt_masks = instances.gt_masks
                gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
                instances.gt_masks = gt_masks

        num_class_obj = {name: 0 for name in self.class_names}

        task = "The task is instance"
        text = self._get_texts(instances.gt_classes, num_class_obj)

        dataset_dict["instances"] = instances
        dataset_dict["orig_shape"] = image_shape
        dataset_dict["task"] = task
        dataset_dict["text"] = text
        dataset_dict["thing_ids"] = self.things

        return dataset_dict
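The mapper reads `thing_classes` and `thing_dataset_id_to_contiguous_id` from the metadata of the first dataset in `cfg.DATASETS.TRAIN`, so those need to be set when the dataset is registered. A sketch of one way to do that, assuming a COCO-format JSON: `build_thing_metadata` and `register_custom_instances` are hypothetical helpers, while `register_coco_instances(name, metadata, json_file, image_root)` is detectron2's real function. The ID mapping follows the same convention as detectron2's `load_coco_json` (categories sorted by ID, contiguous IDs 0..K-1), so the values it later sets agree with these; `ignore_label=255` is an assumed value, included because the mapper reads `meta.ignore_label`.

```python
from typing import Dict, Sequence


def build_thing_metadata(categories: Sequence[dict]) -> Dict[str, object]:
    """Hypothetical helper: derive instance-segmentation metadata from COCO "categories"."""
    cats = sorted(categories, key=lambda c: c["id"])
    return {
        "thing_classes": [c["name"] for c in cats],
        # Dataset category id -> contiguous id in [0, K), in sorted-id order.
        "thing_dataset_id_to_contiguous_id": {c["id"]: i for i, c in enumerate(cats)},
        # Assumed value; custom datasets do not get an ignore_label automatically.
        "ignore_label": 255,
    }


def register_custom_instances(name: str, json_file: str, image_root: str,
                              categories: Sequence[dict]) -> None:
    """Hypothetical helper: register a COCO-format dataset with explicit metadata."""
    # Imported lazily so build_thing_metadata can be used without detectron2 installed.
    from detectron2.data.datasets import register_coco_instances

    register_coco_instances(name, build_thing_metadata(categories), json_file, image_root)


print(build_thing_metadata([{"id": 7, "name": "bolt"}, {"id": 2, "name": "nut"}]))
# {'thing_classes': ['nut', 'bolt'], 'thing_dataset_id_to_contiguous_id': {2: 0, 7: 1}, 'ignore_label': 255}
```

For example, calling `register_custom_instances("my_dataset_train", "annotations/train.json", "images/train", categories)` (placeholder names and paths, with `categories` read from the JSON's "categories" field) before building the trainer, together with `cfg.DATASETS.TRAIN = ("my_dataset_train",)`, lets `from_config` pick this metadata up through `MetadataCatalog.get`.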
I pushed some instructions for training with custom datasets. You may take a look if you face any more issues: https://github.com/SHI-Labs/OneFormer/tree/main/datasets/custom_datasets.
What a legend, thanks! I will check it out later when I have time.
Hi @Robotatron, could you please help me with how to train the model on a custom dataset? Anything would be appreciated. Thanks!
I am using my custom dataset in COCO format for instance segmentation training.
I changed the CFG to: …
I am still getting an error:
UnboundLocalError: local variable 'pan_seg_gt' referenced before assignment
From #5 and the docs, I understand that I have to somehow prepare my dataset for instance segmentation training.
I am also getting a KeyError when using that script (cocodataset/panopticapi#58: KeyError when using detection2panoptic_coco_format).