I am working on an algorithm for detecting the keys and the body of a keyboard in an image. The keyboard model is known, so I created a Blender environment that generates training images with random lighting, camera angles, textures and objects in the scene. I chose an approach based on a U-Net with the following structure:
class RELUConvBlock(nn.Module):
    """3x3 convolution followed by batch normalisation and a ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_ch`` to
    ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The stack is registered as ``model``, so the sub-modules are
        # addressed as model.0 (conv), model.1 (batch norm), model.2 (ReLU).
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.model(x)
class DownBlock(nn.Module):
    """Encoder stage made of two consecutive ``RELUConvBlock`` layers.

    The first layer maps ``in_ch`` to ``out_ch`` channels and the second
    keeps ``out_ch``. The block itself contains no pooling layer.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        widen = RELUConvBlock(in_ch, out_ch)
        refine = RELUConvBlock(out_ch, out_ch)
        self.model = nn.Sequential(widen, refine)

    def forward(self, x):
        """Run ``x`` through both convolution blocks in order."""
        return self.model(x)
class UpBlock(nn.Module):
    """Decoder stage made of two consecutive ``RELUConvBlock`` layers.

    The first layer reduces ``in_ch`` channels to ``out_ch`` and the second
    keeps ``out_ch``. The block itself contains no upsampling layer.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = (
            RELUConvBlock(in_ch, out_ch),
            RELUConvBlock(out_ch, out_ch),
        )
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution blocks in order."""
        return self.model(x)
class UNet(nn.Module):
    """Four-level U-Net that returns one raw logit map per output channel.

    The encoder applies four ``DownBlock`` stages, each followed by a 2x2
    max-pool, then a bottleneck with twice the width of the last stage. The
    decoder mirrors this: a stride-2 transposed convolution doubles the
    spatial size, the matching encoder output is concatenated on the channel
    axis, and an ``UpBlock`` fuses the two.

    Because of the four pooling steps, the output has the same height and
    width as the input only when both are divisible by 16. Otherwise the
    encoder outputs are cropped to the upsampled tensors and the output is
    smaller than the input.

    Args:
        out_ch: Number of output channels produced by the final 1x1 conv.
        down_ch: Exactly four channel widths, one per encoder level, from the
            shallowest to the deepest.
        in_ch: Number of channels of the input tensor.

    Raises:
        ValueError: If ``down_ch`` does not contain exactly four entries.
    """

    def __init__(self, out_ch=3, down_ch=(64, 128, 256, 512), in_ch=3):
        super().__init__()
        widths = tuple(down_ch)
        if len(widths) != 4:
            # The architecture below is hard-wired to four levels; fail at
            # construction time instead of with a shape error in forward().
            raise ValueError(
                f"down_ch must contain exactly 4 channel widths, got {len(widths)}"
            )
        c0, c1, c2, c3 = widths

        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # Encoder.
        self.down0 = DownBlock(in_ch, c0)
        self.down1 = DownBlock(c0, c1)
        self.down2 = DownBlock(c1, c2)
        self.down3 = DownBlock(c2, c3)
        self.bottleneck = DownBlock(c3, 2 * c3)

        # Decoder. Each UpBlock receives the skip tensor concatenated with
        # the upsampled tensor, i.e. twice the width of its level. For widths
        # that double at every level this equals the width of the level
        # below.
        self.up3 = UpBlock(2 * c3, c3)
        self.up2 = UpBlock(2 * c2, c2)
        self.up1 = UpBlock(2 * c1, c1)
        self.up0 = UpBlock(2 * c0, c0)

        # Learned 2x upsampling between decoder levels.
        self.connect_b_up3 = nn.ConvTranspose2d(2 * c3, c3, kernel_size=2, stride=2)
        self.connect_up3_up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.connect_up2_up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.connect_up1_up0 = nn.ConvTranspose2d(c1, c0, kernel_size=2, stride=2)

        self.final_conv = nn.Conv2d(c0, out_ch, kernel_size=1)

    def _crop_to_match(self, tensor, target):
        """Crop ``tensor`` from the top-left to the height/width of ``target``."""
        _, _, h, w = target.shape
        return tensor[:, :, :h, :w]

    def forward(self, x):
        """Return raw, unnormalised logits for ``x``.

        No sigmoid or softmax is applied to the output of the final conv.
        """
        skip_connections = []
        for encoder in (self.down0, self.down1, self.down2, self.down3):
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        decoder_stages = (
            (self.connect_b_up3, self.up3),
            (self.connect_up3_up2, self.up2),
            (self.connect_up2_up1, self.up1),
            (self.connect_up1_up0, self.up0),
        )
        # Walk the skips from the deepest to the shallowest encoder level.
        for (upsample, decoder), skip in zip(decoder_stages, reversed(skip_connections)):
            x = upsample(x)
            # Odd sizes are floored by pooling, so the skip may be one pixel
            # larger than the upsampled tensor; crop it before concatenating.
            skip = self._crop_to_match(skip, x)
            x = torch.cat((skip, x), dim=1)
            x = decoder(x)

        return self.final_conv(x)
The network is trained on 500 images generated from the Blender environment:
# Optimisation hyper-parameters.
LEARNING_RATE = 1e-4
num_epochs = 10

# nn.CrossEntropyLoss expects unnormalised logits as its input.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# autocast accepts only a bare device type ("cuda", "cpu", ...). str(device)
# yields e.g. "cuda:0" for an indexed device, which autocast rejects, so the
# type is extracted explicitly. This works for both str and torch.device.
device_type = torch.device(device).type
scaler = torch.amp.GradScaler(device_type)

model.train()
for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (data, targets) in loop:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass and loss are computed under mixed precision.
        with torch.amp.autocast(device_type=device_type):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # The loss is scaled before backward; scaler.step() unscales the
        # gradients before the optimizer update.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())
While the results on the training and test data are very promising (results_1), the network performs poorly when tested on real photos (results_2). What can be done to fix that? Is there any way other than creating a better Blender environment? Maybe a different type of neural network, or a different solution entirely? This is my first time using AI to solve a real-world problem, so I probably made a lot of bad choices.
I have tried tweaking the parameters, making the Blender images darker to better reflect reality, and using OpenCV to manipulate the results, but none of that worked well.