
Non-deterministic Inference Results After Fine-tuning OneFormer Model #122

Open
PanzaResce opened this issue Sep 21, 2024 · 0 comments

I encountered an issue while fine-tuning the model on a custom dataset. After fine-tuning, I cannot get deterministic results in "inference mode", i.e. after setting `is_training=False` and calling `model.eval()`.

Below is a reproducible code snippet, which is very similar to the fine-tuning notebook in the main repository:

```python
from transformers import AutoProcessor, AutoModelForUniversalSegmentation
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import requests
import numpy as np
import torch

# Load the processor and the model in training mode (is_training=True)
processor = AutoProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
model = AutoModelForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny", is_training=True)

# Set how many text queries the processor generates per image
processor.image_processor.num_text = model.config.num_queries - model.config.text_encoder_n_ctx

class CustomDataset(Dataset):
    """Single-example dataset: one image and its semantic segmentation map."""

    def __init__(self, processor):
        self.processor = processor

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        url = "https://user-images.githubusercontent.com/590151/281234915-de8071bf-0e98-44be-ba9e-d9c9642c704f.jpg"
        image = Image.open(requests.get(url, stream=True).raw)

        # load semantic segmentation map, which labels every pixel
        url = "https://user-images.githubusercontent.com/590151/281234913-9ae307f0-6b57-4a4b-adf6-bed73390c00d.png"
        seg_map = np.array(Image.open(requests.get(url, stream=True).raw))

        # use processor to convert this to a list of binary masks, labels, text inputs and task inputs
        inputs = self.processor(images=image, segmentation_maps=seg_map, task_inputs=["semantic"], return_tensors="pt")
        # drop the batch dimension; list-valued entries keep their single element
        inputs = {k: v.squeeze() if isinstance(v, torch.Tensor) else v[0] for k, v in inputs.items()}

        return inputs

# Create dataset
dataset = CustomDataset(processor)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

optimizer = AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()
model.to(device)
for epoch in range(10):
    for batch in dataloader:
        optimizer.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = model(**batch)

        # backward pass + optimize
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()
```

However, passing the same image through the network twice produces slightly different logits:

```python
model.eval()
model.model.is_training = False

# load image
url = "https://user-images.githubusercontent.com/590151/281234915-de8071bf-0e98-44be-ba9e-d9c9642c704f.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
    outputs_2 = model(**inputs)

# should be True for a deterministic forward pass
print(torch.all(outputs.class_queries_logits == outputs_2.class_queries_logits))
```
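One way to confirm that the mismatch comes from random ops such as dropout (rather than from nondeterministic GPU kernels) is to reset the RNG before each forward pass. The sketch below assumes nothing about OneFormer itself; `outputs_match_when_reseeded` is a hypothetical helper, and the toy `noisy_forward` just stands in for a forward pass whose dropout was left switched on.

```python
import torch
import torch.nn.functional as F


def outputs_match_when_reseeded(fn, seed=0):
    """Call `fn` twice, resetting the global torch RNG before each call.

    If the two results are then identical, the differences seen without
    re-seeding come from RNG-driven ops (e.g. active dropout).
    """
    torch.manual_seed(seed)
    first = fn()
    torch.manual_seed(seed)
    second = fn()
    return torch.equal(first, second)


x = torch.ones(1000)


def noisy_forward():
    # Dropout forced on, mimicking a layer that ignores model.eval()
    return F.dropout(x, p=0.5, training=True)


print(torch.equal(noisy_forward(), noisy_forward()))  # almost surely False
print(outputs_match_when_reseeded(noisy_forward))     # True
```

With the objects from the snippet above this would be roughly `outputs_match_when_reseeded(lambda: model(**inputs).class_queries_logits)` inside `torch.no_grad()`. If it still returns `False`, the remaining differences are more likely due to nondeterministic kernels than to dropout.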

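I believe the issue originates from the dropout layers within `OneFormerPixelDecoderEncoderLayer`. That layer reads `self.is_training` from `OneFormerConfig` when it is constructed, so setting `model.model.is_training = False` afterwards does not propagate to it. Since its dropout appears to be gated by this flag rather than by the module's own `training` attribute, `model.eval()` does not disable it either.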
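A possible workaround, assuming the affected layers store the flag as a plain `is_training` attribute as described above, is to walk every submodule and overwrite it. The sketch below reproduces the pattern with a hypothetical `ToyConfig`/`ToyEncoderLayer` pair (not the real OneFormer classes) and a hypothetical `set_is_training` helper; it is not an official fix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyConfig:
    # Hypothetical stand-in for a config carrying an `is_training` flag
    is_training = True
    dropout = 0.5


class ToyEncoderLayer(nn.Module):
    # Hypothetical layer mimicking the described pattern: the flag is copied
    # from the config once, in __init__, and then used to gate dropout.
    def __init__(self, config):
        super().__init__()
        self.dropout = config.dropout
        self.is_training = config.is_training

    def forward(self, x):
        # Gated by self.is_training, NOT by nn.Module.training,
        # so model.eval() has no effect here.
        return F.dropout(x, p=self.dropout, training=self.is_training)


def set_is_training(model: nn.Module, value: bool) -> int:
    """Set `is_training` on every submodule that defines it; return the count."""
    count = 0
    for module in model.modules():
        if hasattr(module, "is_training"):
            module.is_training = value
            count += 1
    return count


toy_model = nn.Sequential(ToyEncoderLayer(ToyConfig()), ToyEncoderLayer(ToyConfig()))
x = torch.ones(1000)

toy_model.eval()
with torch.no_grad():
    print(torch.equal(toy_model(x), toy_model(x)))  # almost surely False: dropout still active

set_is_training(toy_model, False)
with torch.no_grad():
    print(torch.equal(toy_model(x), toy_model(x)))  # True: dropout now disabled
```

Applied to the fine-tuned model, the equivalent call would be `set_is_training(model, False)` after `model.eval()`. Note that it flips every module exposing `is_training`, including `model.model`, which matches what inference mode already expects.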
