catbook DeokJin about AI

Self-Distillation for Improving CTC-Transformer-based ASR Systems (ISCA 2020)

» paper

Interspeech 2020에 개제된 논문이다.

Abstract

Sequence-to-sequence(S2S) 모델은 안정적으로 ASR 시스템에 정착되었고, 핵심 요소는 attention 메커니즘이 입력과 아웃풋인 캐릭터를 정렬(alignment)하는 것이다. 어텐션의 가중치는 입력과 출력의 시퀀스의 타임 프레임을 알려준다. 본 연구에서는 매 시점의 트랜스포머의 출력과 어텐션의 가중치를 사용하여 pseudo-targets을 생성하고, CTC-Transformer의 shared encoder를 Transformer-decoder를 통해 피드백 받아 더 많은 정보를 학습한다. 증류 기법은 흔하지만, 이를 디코더로 부터 피드백을 받는 방법에 대해서는 감이 안와서 method를 더 세밀하게 분석할 필요가 있어보인다. 이런 방식으로 모델을 학습하면 S2S 역시 공유 인코더로 인해서 성능이 강화된다고 한다.

ASR Modules

baseline은 Transformer와 CTC이다. Transformer의 output은 문장의 끝을 나타내는 또는 토큰을 포함한다.

Transformer

multi-head- attention (MHA)

트랜스포머는 인코더와 디코더 블록으로 구성되있다. 각 블록은 scale dot-product attention을 가지고 있음. 최근 등장한 모델은 대부분 attention 메커니즘을 사용하기 때문에 여기서 구현과 함께 공부해보도록 한다.

image

class ScaledDotProductAttention(nn.Module):

    def forward(self, query, key, value, mask=None):
        dk = query.size()[-1]
        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        return attention.matmul(value)

Attention의 내부 구조는 query와 key의 transpose를 내적하고, softmax를 취한후 attention score를 구하고 value와 곱하여 context를 구하는 연산이다. multi-head attention은 각각의 Q, K, V를 head 개수만큼 분리하여 각각을 따로 연산한다. 보통 attention을 사용할 때, 파라미터로 Q, K, V 그리고 mask를 받게 된다. mask는 가중치 값을 -inf와 같이 매우 큰 음수 값으로 보내버려서 softmax에서 0을 얻게 한다. Transformer에서는 square subsequent mask와 key padding mask가 있다. 두개의 차이점은 시퀀스에 대응을 하느냐 아니면 batch에 대응을 하느냐이다.

class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num,
                 bias=True,
                 activation=F.relu):
        """Multi-head attention.
        :param in_features: Size of each input sample.
        :param head_num: Number of heads.
        :param bias: Whether to use the bias term.
        :param activation: The activation after each linear transformation.
        """
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.linear_q(q), self.linear_k(k), self.linear_v(v)
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)

        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if mask is not None:
            mask = mask.repeat(self.head_num, 1, 1)
        y = ScaledDotProductAttention()(q, k, v, mask)
        y = self._reshape_from_batches(y)

        y = self.linear_o(y)
        if self.activation is not None:
            y = self.activation(y)
        return y

    @staticmethod
    def gen_history_mask(x):
        """Generate the mask that only uses history data.
        :param x: Input tensor.
        :return: The mask.
        """
        batch_size, seq_len, _ = x.size()
        return torch.tril(torch.ones(seq_len, seq_len)).view(1, seq_len, seq_len).repeat(batch_size, 1, 1)

    def _reshape_to_batches(self, x):
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)

    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation,
        )

입력된 Q, K, V와 동일한 가중치 매트릭스를 각각 연산하고, head의 개수만큼 벡터를 조각낸다. 그 후 위에서 attention 연산을 그대로 적용하면 각 헤드끼리 연산이 되고, 최종적으로 다시 head를 하나의 벡터로 선형 결합하는 구조이다.

Speech encoder of Transformer

가장 초기의 입력을 Conv2D로 피처 맵을 형성하고 linear embedding 한다. 이 과정을 subsampling이라고 부르는데 피처 맵의 압축이라고 이해하였다.

image

임베딩된 결과는 각 time step에 따라서 위치적으로 인코딩 되고, positional encoding(PE)은 Wav2Vec 2.0과 다르게 sinusoidal positional encoding을 사용하였고, 임베딩 된 각각의 time step에 따라서 아래의 수식과 같이 위치 인코딩 된다. 이에 대한 설명과 구현은 Positional Encoder으로 따로 정리를 해야겠다.

image

class PositionalEmbedding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]