본문 바로가기

Paper

Improving drug response prediction via integrating gene relationships with deep learning

Improving drug response prediction
via integrating gene relationships with deep learning

https://academic.oup.com/bib/article/25/3/bbae153/7642699

2024, Briefings in Bioinformatics

1. 입력 피처 (Input Features)

이 구조에서 어텐션 메커니즘에 입력되는 피처는 크게 세 가지입니다.

  • Atom features: 약물의 분자 구조적 특징입니다.
  • Interactome feature: 세포 내 단백질 상호작용 네트워크 특징입니다.
  • Transcriptome feature: 세포의 유전자 발현 데이터 특징입니다.

2. 멀티 헤드 어텐션(MHA)을 통한 퓨전 메커니즘

모델은 세포 피처( Interactome  , Transcriptome )와 약물 피처( Atom ) 간의 고도화된 상호작용을 학습하기 위해 두 개의 독립적인 멀티 헤드 어텐션 레이어를 사용합니다.

이때 트랜스포머 아키텍처의 Query(Q), Key(K), Value(V) 관계는 다음과 같이 정의됩니다.

  • Interactome fusion pipeline(첫 번째 multi-head attention):
    • Q (Query): interactome fusion을 활성화 함수가 없는 선형 레이어(Linear Layer)에 통과시켜 얻습니다.
    • K (Key) & V (Value): 약물의 원자 피처를 독립된 선형 레이어에 통과시켜 얻습니다.
  • Transcriptome fusion pipeline (두 번째 multi-head attention):
    • Q (Query): Transcriptome feature를 동일하게 선형 레이어에 통과시켜 얻습니다.
    • K (Key) & V (Value): interactome 파이프라인과 마찬가지로 약물의 원자 피처로부터 얻습니다.

3. 매개변수 공유 (Parameter Sharing) 전략

텍스트에서는 이 두 파이프라인 사이의 가중치(Parameter) 공유 여부에 대한 중요한 전략을 언급하고 있습니다.

  • 일반적인 경우 (가중치 공유 ⭕): 인터랙톰 파이프라인과 전사체 파이프라인의 선형 레이어 및 MHA 레이어의 파라미터를 서로 공유합니다. 이는 모델이 서로 다른 입력 피처 간의 공통적인 패턴을 학습하게 하여, 보지 못한 새로운 세포주나 약물에 대한 일반화 능력(Generalization Ability)을 높입니다.
  • 기존 학습된 데이터 예측 시 (가중치 공유 ❌): 이미 학습된 세포주와 약물 사이의 반응을 정밀하게 예측해야 할 때는 일반화보다 정밀도가 중요하므로, 두 파이프라인의 파라미터를 독립적으로 유지하여 더 섬세하고 특화된 학습(Nuanced Learning)을 수행합니다.

4. 퓨전 이후 최종 출력으로의 결합

두 개의 멀티 헤드 어텐션 레이어를 거친 후, 피처들은 다음과 같은 흐름으로 최종 통합(Fusion)됩니다.

  1. 약물 인코딩(Molecular Feature): 두 MHA 레이어에서 나온 각각의 출력을 Element-wise Sum. 
  2. 세포 인코딩: 인터랙톰 피처와 전사체 피처는 각각 ReLU 활성화 함수가 포함된 선형 레이어를 거친 두 결과를 Element-wise Sum .
  3. 최종 결합 (Concatenation): 앞서 유도된 '분자 피처'와 '재인코딩된 세포 피처의 합'을 하나로 이어 붙입니다. 최종 결합된 피처의 차원은 [768 + 512]이 됩니다.
  4. 예측 (MLP 단계): 이 통합 피처는 점진적으로 크기가 줄어드는 완전 연결 레이어(Fully Connected Layers: [768+512] ➔ 512 ➔ 256 ➔ 128 ➔ 1)를 거쳐, 최종적으로 약물 효과를 나타내는 수치인 LN IC50(반합성 억제 농도의 자연로그 값)을 예측하게 됩니다. 이 과정은 회귀(Regression) 문제이므로 MSE(평균제곱오차) 손실 함수를 통해 학습됩니다.