Yêu cầu thg 12 5, 2023 4:41 SA 102 0 1
  • 102 0 1
+1

Lỗi output size không giống với input size trong hàm training trong bài toán semantic segmentation

Chia sẻ
  • 102 0 1

E chào các anh chị, e đang làm 1 project nhỏ liên quan tới semantic segmentation cho đối tượng là dây safety wire. Mọi data như ảnh và anotation e đã có, nhưng khi chạy hàm training thì bị vướng khi xuất hiện lỗi output size không giống với input size, hiện nay e vẫn chưa biết cách khắc phục, mong các anh chị có thể chỉ dẫn giúp em ạ. E cảm ơn anh chị rất nhiều!

for ep in range(1, 1+n_eps):
    acc_meter.reset()
    train_loss_meter.reset()
    dice_meter.reset()
    iou_meter.reset()
    model.train()

    for batch_id, (x, y) in enumerate(tqdm(trainloader), start=1):
        optimizer.zero_grad()
        n = x.shape[0]     #kích thước batch_size
        x = x.to(device).float()
        y = y.to(device).float()
        print("y: ")
        print(y)
        y_hat = model(x)    #y predict
        y_hat = y_hat.squeeze()
        #print("y_hat: ")
        #print(y_hat)
        #if y == y_hat:
          loss = criterion(y_hat, y)
          loss.backward()

        optimizer.step()

        with torch.no_grad():
            y_hat_mask = y_hat.sigmoid().round().long() # -> mask (0, 1)
            dice_score = dice_fn(y_hat_mask, y.long())
            iou_score = iou_fn(y_hat_mask, y.long())
            accuracy = accuracy_function(y_hat_mask, y.long())

            train_loss_meter.update(loss.item(), n)
            iou_meter.update(iou_score.item(), n)
            dice_meter.update(dice_score.item(), n)
            acc_meter.update(accuracy.item(), n)

    print("EP {}, train loss = {}, accuracy = {}, IoU = {}, dice = {}".format(
        ep, train_loss_meter.avg, acc_meter.avg, iou_meter.avg, dice_meter.avg
    ))
    if ep >= 25:
        torch.save(model.state_dict(), "/content/model_ep_{}.pth".format(ep))

1 CÂU TRẢ LỜI


Đã trả lời thg 12 5, 2023 7:53 SA
Đã được chấp nhận
+2

Lỗi "output size không giống với input size" thường xảy ra khi kích thước đầu ra của một lớp hoặc mạng khác với kích thước đầu vào. Điều này có thể xảy ra do một số nguyên nhân, chẳng hạn như:

Chèn hoặc cắt xén không chính xác: Đảm bảo rằng các thao tác chèn hoặc cắt xén được áp dụng chính xác và không thay đổi kích thước đầu ra mong đợi. Chưa có hoặc upsampling không chính xác: Nếu bạn đang sử dụng các lớp upsampling để tăng độ phân giải không gian của đầu ra, hãy đảm bảo các hệ số upsampling và phương pháp nội suy phù hợp. Các hàm kích hoạt không tương thích: Nếu bạn đang sử dụng các hàm kích hoạt làm thay đổi kích thước đầu ra, chẳng hạn như Sigmoid hoặc Tanh, hãy đảm bảo chúng được áp dụng chính xác và không thay đổi kích thước đầu ra mong đợi. Hướng dẫn cụ thể cho mã của bạn

Trong mã của bạn, lỗi có thể xảy ra do thao tác squeeze() sau y_hat = model(x). Thao tác squeeze() loại bỏ các chiều đơn lẻ khỏi tensor, điều này có thể thay đổi kích thước đầu ra mong đợi.

Để khắc phục vấn đề này, bạn có thể thử một trong các phương pháp sau:

Xóa thao tác squeeze(): Thao tác này sẽ giữ nguyên kích thước đầu ra của y_hat. Thay thế thao tác squeeze() bằng view(): Thao tác view() cho phép bạn sửa đổi hình dạng của tensor. Bạn có thể sử dụng view() để giữ nguyên kích thước đầu ra của y_hat. Sử dụng các kỹ thuật giảm chiều: Các kỹ thuật giảm chiều như AdaptiveMaxPool2d hoặc AdaptiveAvgPool2d có thể điều chỉnh kích thước đầu ra của y_hat để khớp với hình dạng đầu vào. Ví dụ

Giả sử kích thước đầu vào của x là (1, 3, 224, 224) và kích thước đầu ra mong đợi của y_hat là (1, 2, 224, 224).

Nếu bạn xóa thao tác squeeze(), kích thước đầu ra của y_hat sẽ vẫn là (1, 3, 224, 224).

Nếu bạn thay thế thao tác squeeze() bằng view(), bạn có thể sử dụng mã sau để giữ nguyên kích thước đầu ra:

y_hat = y_hat.view(1, 2, 224, 224) Nếu bạn sử dụng các kỹ thuật giảm chiều, bạn có thể sử dụng mã sau để điều chỉnh kích thước đầu ra:

y_hat = AdaptiveMaxPool2d((2, 2))(y_hat) Mã này sẽ sử dụng lớp AdaptiveMaxPool2d với kích thước pooling là (2, 2) để giảm kích thước đầu ra của y_hat xuống (1, 2, 224, 224).

Chia sẻ
Viblo
Hãy đăng ký một tài khoản Viblo để nhận được nhiều bài viết thú vị hơn.
Đăng kí