DSN Bootcamp 2024: Leveraging Graph-based Techniques for Enhanced Learning in Language Models by Emeka Okoye
# Core dependencies
pip install torch transformers torch_geometric networkx pandas
pip install spacy scikit-learn
python -m spacy download en_core_web_sm
# Optional visualization
pip install matplotlib seaborn
class MedicalEntity:
def __init__(self, id, name, type):
self.id = id
self.name = name
self.type = type # medication, condition, side_effect
self.attributes = {}
class Relationship:
def __init__(self, source, target, type, weight=1.0):
self.source = source
self.target = target
self.type = type # treats, causes, contraindicates
self.weight = weight
class MedicalKnowledgeGraph:
def __init__(self):
self.graph = nx.Graph()
self.entity_embeddings = {}
def add_entity(self, entity):
self.graph.add_node(entity.id,
name=entity.name,
type=entity.type,
attributes=entity.attributes)
def add_relationship(self, relationship):
self.graph.add_edge(relationship.source,
relationship.target,
type=relationship.type,
weight=relationship.weight)
class MedicalGNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, output_dim)
self.dropout = nn.Dropout(0.2)
def forward(self, x, edge_index, edge_attr=None):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x
class IntegrationLayer(nn.Module):
def __init__(self, llm_dim, gnn_dim, output_dim):
super().__init__()
self.attention = nn.MultiheadAttention(llm_dim, num_heads=8)
self.fusion = nn.Linear(llm_dim + gnn_dim, output_dim)
def forward(self, llm_features, graph_features):
attended_features, _ = self.attention(llm_features,
graph_features,
graph_features)
combined = torch.cat([attended_features, graph_features], dim=-1)
return self.fusion(combined)
class EntityExtractor:
def __init__(self):
self.nlp = spacy.load('en_core_web_sm')
def extract_entities(self, text):
doc = self.nlp(text)
medical_entities = []
for ent in doc.ents:
if ent.label_ in ['DRUG', 'CONDITION', 'SYMPTOM']:
medical_entities.append({
'text': ent.text,
'label': ent.label_,
'start': ent.start_char,
'end': ent.end_char
})
return medical_entities
class GraphQueryEngine:
def __init__(self, knowledge_graph):
self.kg = knowledge_graph
self.cache = {}
def query(self, entities, relation_types=None):
cache_key = f"{'-'.join(entities)}_{'-'.join(relation_types or [])}"
if cache_key in self.cache:
return self.cache[cache_key]
subgraph = self.kg.get_subgraph(entities, relation_types)
self.cache[cache_key] = subgraph
return subgraph
class ResponseTemplate:
def __init__(self):
self.templates = {
'side_effect': "Based on medical knowledge, {medication} may cause: {effects}",
'contraindication': "CAUTION: {condition} may be a contraindication for {medication}",
'interaction': "Potential interaction between {med1} and {med2}: {details}"
}
def format_response(self, template_key, **kwargs):
template = self.templates[template_key]
return template.format(**kwargs)
class SafetyChecker:
def __init__(self):
self.disclaimers = [
"Consult healthcare provider for medical advice.",
"This information is for educational purposes only."
]
def check_response(self, response):
# Add safety checks and disclaimers
checked_response = response + "\n\n"
checked_response += "IMPORTANT: " + " ".join(self.disclaimers)
return checked_response
class MedicalQASystem:
def __init__(self):
self.kg = MedicalKnowledgeGraph()
self.gnn = MedicalGNN(input_dim=768, hidden_dim=256, output_dim=128)
self.llm = AutoModel.from_pretrained('bert-base-uncased')
self.entity_extractor = EntityExtractor()
self.query_engine = GraphQueryEngine(self.kg)
self.integration_layer = IntegrationLayer(768, 128, 256)
self.response_generator = ResponseTemplate()
self.safety_checker = SafetyChecker()
def process_query(self, question):
# 1. Extract entities
entities = self.entity_extractor.extract_entities(question)
# 2. Get LLM features
llm_features = self.llm(question)
# 3. Query knowledge graph
graph_data = self.query_engine.query(entities)
# 4. Process with GNN
graph_features = self.gnn(graph_data)
# 5. Integrate information
combined_features = self.integration_layer(llm_features, graph_features)
# 6. Generate response
response = self.response_generator.format_response(
template_key='side_effect',
medication=entities[0]['text'],
effects=graph_data['side_effects']
)
# 7. Safety check
safe_response = self.safety_checker.check_response(response)
return safe_response
class PerformanceMonitor:
def __init__(self):
self.metrics = {
'response_time': [],
'accuracy': [],
'entity_extraction_success': [],
'graph_query_hits': []
}
def log_metric(self, metric_name, value):
self.metrics[metric_name].append(value)
def get_summary(self):
return {k: np.mean(v) for k, v in self.metrics.items()}
class CacheManager:
def __init__(self, max_size=1000):
self.max_size = max_size
self.cache = OrderedDict()
def get(self, key):
if key in self.cache:
self.cache.move_to_end(key)
return self.cache[key]
return None
def put(self, key, value):
if len(self.cache) >= self.max_size:
self.cache.popitem(last=False)
self.cache[key] = value
self.cache.move_to_end(key)
# Initialize the system
qa_system = MedicalQASystem()
# Process a question
question = "What are the potential side effects of metformin for a person with kidney issues?"
response = qa_system.process_query(question)
print(response)