I encountered an issue while fine-tuning the model on a custom dataset. After fine-tuning, I cannot get deterministic results when running the model in "inference mode", that is, after setting is_training=False and calling model.eval().
Below is a reproducible code snippet, which is very similar to the fine-tuning notebook in the main repository:
from transformers import AutoProcessor, AutoModelForUniversalSegmentation
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import requests
import numpy as np
import torch
# Loading model
processor = AutoProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
model = AutoModelForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny", is_training=True)
processor.image_processor.num_text = model.config.num_queries - model.config.text_encoder_n_ctx
class CustomDataset(Dataset):
    def __init__(self, processor):
        self.processor = processor

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        url = "https://user-images.githubusercontent.com/590151/281234915-de8071bf-0e98-44be-ba9e-d9c9642c704f.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        # load semantic segmentation map, which labels every pixel
        url = "https://user-images.githubusercontent.com/590151/281234913-9ae307f0-6b57-4a4b-adf6-bed73390c00d.png"
        map = Image.open(requests.get(url, stream=True).raw)
        map = np.array(map)
        # use processor to convert this to a list of binary masks, labels, text inputs and task inputs
        inputs = self.processor(images=image, segmentation_maps=map, task_inputs=["semantic"], return_tensors="pt")
        inputs = {k: v.squeeze() if isinstance(v, torch.Tensor) else v[0] for k, v in inputs.items()}
        return inputs
# Create dataset
dataset = CustomDataset(processor)
# Create dataloader
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
optimizer = AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()
model.to(device)
for epoch in range(10):
    for batch in dataloader:
        optimizer.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        # forward pass
        outputs = model(**batch)
        # backward pass + optimize
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()
However, when passing the same image twice through the network, the logits differ slightly.
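The comparison can be sketched roughly as follows. This is a minimal sketch, not the exact check I ran: `max_output_diff` is a hypothetical helper, and a toy `nn.Sequential` stands in for the fine-tuned model so the snippet runs without downloading any weights.

```python
import torch
from torch import nn


def max_output_diff(forward, n_runs=2):
    """Call `forward()` n_runs times and return the largest absolute
    deviation of any later result from the first one."""
    with torch.no_grad():
        reference = forward()
        return max(
            (forward() - reference).abs().max().item() for _ in range(n_runs - 1)
        )


# Toy stand-in for the fine-tuned model (assumption: NOT the real OneFormer).
toy = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))
x = torch.randn(16, 64)

toy.eval()   # dropout disabled, so repeated passes should agree exactly
diff_eval = max_output_diff(lambda: toy(x))

toy.train()  # dropout active, so repeated passes should differ
diff_train = max_output_diff(lambda: toy(x))

print(f"eval: {diff_eval}, train: {diff_train}")
```

For the model in this issue, the same helper could presumably be called as `max_output_diff(lambda: model(**batch).class_queries_logits)` (and likewise for `masks_queries_logits`) after `model.eval()`. A non-zero result there is the behaviour described above.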
I believe the issue originates from the dropout layers within the OneFormerPixelDecoderEncoderLayer. Its self.is_training attribute appears to be tied directly to the OneFormerConfig, so setting is_training=False later is not reflected in this module.
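To illustrate the suspected mechanism, here is a toy reproduction. It is only a sketch under the assumption that the layer passes a flag stored at construction time to functional dropout. Both classes below are hypothetical and are not the library's actual code.

```python
import torch
from torch import nn


class ConfigBoundDropout(nn.Module):
    """Hypothetical layer: the dropout flag is copied from a config value
    at construction time and never updated afterwards."""

    def __init__(self, is_training, p=0.5):
        super().__init__()
        self.is_training = is_training
        self.p = p

    def forward(self, x):
        # Ignores self.training, so model.eval() has no effect here.
        return nn.functional.dropout(x, p=self.p, training=self.is_training)


class ModeBoundDropout(nn.Module):
    """Hypothetical layer: follows the module's own train/eval mode."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # self.training is toggled by .train() / .eval().
        return nn.functional.dropout(x, p=self.p, training=self.training)


x = torch.ones(1000)
config_bound = ConfigBoundDropout(is_training=True).eval()
mode_bound = ModeBoundDropout().eval()

# With p=0.5, active dropout maps every 1.0 to either 0.0 or 2.0.
config_bound_is_identity = torch.equal(config_bound(x), x)
mode_bound_is_identity = torch.equal(mode_bound(x), x)
print(config_bound_is_identity, mode_bound_is_identity)
```

In this toy setup, calling `.eval()` only disables dropout in the layer that reads `self.training`. The config-bound layer keeps dropping activations, which would explain logits that change between identical forward passes.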