Implementing and Experimenting with the Multiscale Multi-head Self-attention Ensemble Network for Cancer Detection

Introduction

Breast cancer is the most commonly diagnosed cancer and one of the leading causes of cancer-related death among women. A definitive diagnosis requires a tissue sample taken from the patient. Digital pathology scanners record this tissue at very high resolution, producing Whole Slide Images (WSIs), which skilled pathologists examine to reach a diagnosis. Deep learning techniques can assist this process or even automate it. To this end, R. Ge et al. propose the Multiscale Multi-head Self-attention Ensemble Network, a heterogeneous deep ensemble learning approach. The intermediate feature vectors produced by VGG16 and DenseNet121, both pretrained on the ImageNet-1k dataset, are combined using a self-attention layer, followed by global average pooling. Deep ensembles generally tend to generalize better than single networks.

The model is trained on the PatchCamelyon (PCam) benchmark dataset. It poses a binary classification task: an image patch is labeled 1 if cancerous tissue is present and 0 otherwise. The patches are extracted from 400 WSIs from Radboud University Medical Center (RUMC) and University Medical Center Utrecht (UMCU). Expert pathologists analyzed the slides, and their annotations were used to extract and label the diagnostic patches.


Model Architecture

Model Training and Evaluation

Binary cross-entropy (BCE) is employed as the loss function. Given a model $f$ that outputs the predicted probability of the positive class and data points $\{(x_i, y_i)\}_{i=1}^n$ with labels $y_i \in \{0, 1\}$, the loss is computed as follows:

$$l_{BCE} = -\sum_{i=1}^n \left[ y_i \cdot \log(f(x_i)) + (1 - y_i) \cdot \log(1 - f(x_i)) \right].$$
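A minimal pure-Python sketch of the BCE loss, assuming $f(x_i)$ is already a probability in $(0, 1)$ (for example a sigmoid output). The code negates the summed log-likelihood so that the loss is non-negative and is minimized during training. The function name `bce_loss` and the clipping constant `eps`, which keeps `log(0)` from occurring, are our own assumptions and not part of the described method.

```python
import math
from typing import Sequence


def bce_loss(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-7) -> float:
    """Summed binary cross-entropy over all data points.

    probs  -- predicted probabilities f(x_i) of the positive class
    labels -- ground-truth labels y_i in {0, 1}
    eps    -- assumed clipping constant that keeps log() finite
    """
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip the probability into (0, 1)
        total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return -total  # negate the log-likelihood to obtain a loss


print(bce_loss([0.9, 0.2], [1, 0]))  # -(log 0.9 + log 0.8) ≈ 0.3285
```

Deep learning frameworks usually average over the batch instead of summing; PyTorch's `torch.nn.BCELoss`, for instance, uses `reduction='mean'` by default. The two variants differ only by the constant factor $1/n$.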

The model weights are updated with the Adam optimizer, using gradients computed by backpropagation. The authors propose a custom learning rate schedule:

$$\eta(t) = \begin{cases} \eta_0, & t = 1 \\ \eta_{min} + \frac{1}{2} (\eta_0 - \eta_{min}) \left( 1 + \cos\left(\pi \frac{t - 1}{T - 1}\right)\right), & 1 < t \le T \end{cases}$$

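where $\eta_0 = 0.001$, $\eta_{min} = 0.00001$, $t$ is the current epoch (counted from 1) and $T$ is the overall number of epochs. The learning rate thus decays along a half cosine from $\eta_0$ in the first epoch to $\eta_{min}$ in the last one.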
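The schedule can be sketched as a plain Python function, assuming epochs are numbered $1, \dots, T$ as in the formula. The function name `learning_rate` is hypothetical; its defaults are the values $\eta_0 = 0.001$ and $\eta_{min} = 0.00001$ given above.

```python
import math


def learning_rate(t: int, T: int, eta0: float = 1e-3, eta_min: float = 1e-5) -> float:
    """Cosine learning rate schedule eta(t) for epoch t in 1..T."""
    if not 1 <= t <= T:
        raise ValueError("epoch t must satisfy 1 <= t <= T")
    if t == 1:
        return eta0  # first case of eta(t); also avoids 0/0 when T == 1
    progress = (t - 1) / (T - 1)  # runs from 0 (first epoch) to 1 (last epoch)
    return eta_min + 0.5 * (eta0 - eta_min) * (1 + math.cos(math.pi * progress))


T = 5
for t in range(1, T + 1):
    print(t, f"{learning_rate(t, T):.6f}")
```

If the model is trained with PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T - 1, eta_min=1e-5)`, stepped once per epoch with an initial learning rate of $\eta_0$, should yield the same per-epoch values up to floating-point differences. This equivalence is our reading of the formula, not something stated by the authors.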

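To assess the performance of the model, seven metrics are employed. Apart from the area under the ROC curve (AUC), all of them are computed from four quantities: the numbers of true positives (TP), false positives (FP), true negatives (TN) and false negatives (FN) obtained after thresholding the model output.

The AUC measures how well the model ranks positive samples above negative ones. With $P$ and $N$ denoting the sets of positive and negative samples, it is computed as

$$\mathrm{AUC} = \frac{1}{|P||N|} \sum_{p \in P} \sum_{n \in N} \left( \mathbb{1}_{f(p) > f(n)} + \frac{1}{2} \cdot \mathbb{1}_{f(p) = f(n)} \right),$$

where $\mathbb{1}$ denotes the indicator function.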
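The pairwise AUC formula can be implemented directly by comparing every positive score with every negative score, counting ties as one half. This is a minimal sketch; the function name `auc` is hypothetical, and the $O(|P| \cdot |N|)$ loop is meant for illustration rather than for large test sets.

```python
from typing import Sequence


def auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    total = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                total += 1.0  # positive correctly ranked above negative
            elif p == n:
                total += 0.5  # tie counts one half
    return total / (len(pos) * len(neg))


print(auc([0.9, 0.8, 0.4, 0.3], [1, 0, 1, 0]))  # 3 of 4 pairs correct -> 0.75
```

For larger datasets, a library routine such as scikit-learn's `sklearn.metrics.roc_auc_score` computes the same quantity more efficiently.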

Precision measures the proportion of positive predictions that are actually correct:

$$\mathrm{Precision} = \frac{TP}{TP + FP}.$$

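Sensitivity (also called recall or true positive rate) measures the proportion of positive samples that are correctly identified:

$$\mathrm{Sensitivity} = \frac{TP}{TP + FN}.$$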

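Specificity (true negative rate) measures the proportion of negative samples that are correctly identified:

$$\mathrm{Specificity} = \frac{TN}{TN + FP}.$$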

The F1-score is the harmonic mean of precision and sensitivity:

$$F_1 = 2 \cdot \frac{\mathrm{Precision} \cdot \mathrm{Sensitivity}}{\mathrm{Precision} + \mathrm{Sensitivity}}.$$

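Balanced accuracy (B-acc) is the mean of sensitivity and specificity, so both classes contribute equally regardless of how frequent they are:

$$\text{B-acc} = \frac{1}{2} (\mathrm{Sensitivity} + \mathrm{Specificity}).$$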

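The Matthews correlation coefficient (MCC) takes all four quantities into account and ranges from $-1$ (total disagreement) over $0$ (no better than chance) to $1$ (perfect prediction):

$$\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}.$$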
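The six threshold-based metrics can be sketched in plain Python as follows. The helper names `confusion_counts` and `classification_metrics` are hypothetical. The decision threshold of 0.5 used to turn probabilities into hard predictions and the convention of reporting 0 when a denominator vanishes are our assumptions; the text specifies neither.

```python
import math
from typing import Dict, Sequence, Tuple


def confusion_counts(probs: Sequence[float], labels: Sequence[int],
                     threshold: float = 0.5) -> Tuple[int, int, int, int]:
    """Count (TP, FP, TN, FN); threshold=0.5 is an assumed decision boundary."""
    tp = fp = tn = fn = 0
    for p, y in zip(probs, labels):
        pred = 1 if p >= threshold else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 0:
            tn += 1
        else:  # pred == 0 and y == 1
            fn += 1
    return tp, fp, tn, fn


def _safe_div(num: float, den: float) -> float:
    # Assumed convention: a metric with a zero denominator is reported as 0.
    return num / den if den else 0.0


def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> Dict[str, float]:
    """Precision, sensitivity, specificity, F1, B-acc and MCC from the counts."""
    precision = _safe_div(tp, tp + fp)
    sensitivity = _safe_div(tp, tp + fn)
    specificity = _safe_div(tn, tn + fp)
    f1 = _safe_div(2 * precision * sensitivity, precision + sensitivity)
    b_acc = 0.5 * (sensitivity + specificity)
    mcc = _safe_div(tp * tn - fp * fn,
                    math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    return {"precision": precision, "sensitivity": sensitivity,
            "specificity": specificity, "f1": f1, "b_acc": b_acc, "mcc": mcc}


counts = confusion_counts([0.9, 0.6, 0.2, 0.4], [1, 0, 0, 1])
print(counts)  # (1, 1, 1, 1)
print(classification_metrics(40, 10, 45, 5))
```

Because these six metrics depend on hard predictions, their values change with the chosen threshold, whereas the AUC is computed from the raw scores.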