While working on a classification task recently, I started out with the well-known Naive Bayes classifier but soon realized that I wasn't getting the classification performance I needed. At the same time, I was keenly aware that many of the feature variables carry causal relationships among themselves that the classifier wasn't exploiting.
I needed another classifier with the simplicity and run-time performance of Naive Bayes, so I started looking into Tree-Augmented Naive Bayes (TAN). My search for a Python library supporting TAN eventually led me to pgmpy.
Pgmpy is a Python Library for learning (Structure and Parameter) and inference (Statistical and Causal) in Bayesian Networks.
Pgmpy is versatile, with capabilities well beyond what my task needed. It also includes several structure learning algorithms, but not TAN, so I decided to contribute one to the library. In this article, I'll give a tutorial on how to use TAN in pgmpy.
For demonstration purposes, I'll generate sample data from a handcrafted Bayesian Network (BN) model. With the generated data, I'll train a Naive Bayes classifier and a TAN classifier, and compare their prediction performance. Our BN graph is illustrated below. Node C is our class variable, while nodes R, S, T, U, and V are the feature variables: C points to every feature, and R additionally points to S, T, U, and V. All variables are discrete.
Generate Sample Data
First, import all the packages we will be needing:
import pandas as pd
import numpy as np
from pgmpy.models import BayesianModel
from pgmpy.factors.discrete import TabularCPD
from pgmpy.sampling import BayesianModelSampling
from pgmpy.estimators import TreeSearch, BayesianEstimator
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.naive_bayes import MultinomialNB
To encode our BN graph, we instantiate BayesianModel with the directed edges of the graph:
# construct naive bayes graph and add interaction between the features
model = BayesianModel([('C', 'R'), ('C', 'S'), ('C', 'T'), ('C', 'U'), ('C', 'V'),
                       ('R', 'S'), ('R', 'T'), ('R', 'U'), ('R', 'V')])
Next, we parameterize the graph by defining the conditional probabilities for each node. For more information on parameterization, check out the pgmpy tutorials.
# define the conditional probability distribution (CPD) for each node
cpd_c = TabularCPD('C', 2, [[0.5], [0.5]])
cpd_r = TabularCPD('R', 3, [[0.6, 0.2], [0.3, 0.5], [0.1, 0.3]],
                   evidence=['C'], evidence_card=[2])
cpd_s = TabularCPD('S', 3, [[0.1, 0.1, 0.2, 0.2, 0.7, 0.1],
                            [0.1, 0.3, 0.1, 0.2, 0.1, 0.2],
                            [0.8, 0.6, 0.7, 0.6, 0.2, 0.7]],
                   evidence=['C', 'R'], evidence_card=[2, 3])
cpd_t = TabularCPD('T', 2, [[0.7, 0.2, 0.2, 0.5, 0.1, 0.3],
                            [0.3, 0.8, 0.8, 0.5, 0.9, 0.7]],
                   evidence=['C', 'R'], evidence_card=[2, 3])
cpd_u = TabularCPD('U', 3, [[0.3, 0.8, 0.2, 0.8, 0.4, 0.7],
                            [0.4, 0.1, 0.4, 0.1, 0.1, 0.1],
                            [0.3, 0.1, 0.4, 0.1, 0.5, 0.2]],
                   evidence=['C', 'R'], evidence_card=[2, 3])
cpd_v = TabularCPD('V', 2, [[0.5, 0.6, 0.6, 0.5, 0.5, 0.4],
                            [0.5, 0.4, 0.4, 0.5, 0.5, 0.6]],
                   evidence=['C', 'R'], evidence_card=[2, 3])
model.add_cpds(cpd_c, cpd_r, cpd_s, cpd_t, cpd_u, cpd_v)
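As an optional sanity check, we can ask pgmpy to confirm that the CPDs we just attached are consistent with the graph before sampling from it:
# check_model() returns True if every node has a valid CPD that matches its parents,
# and raises an error otherwise
print(model.check_model())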
With our BN model defined, we can now generate sample data:
# generate sample data from our BN model
inference = BayesianModelSampling(model)
df_data = inference.forward_sample(size=30000, return_type='dataframe')
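Before moving on, it's worth peeking at the generated samples. This is plain pandas, nothing pgmpy-specific:
# inspect the first few samples and the class balance
print(df_data.head())
print(df_data['C'].value_counts(normalize=True))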
We'll be using the data to train the Naive Bayes and TAN classifiers and see which one performs better. Before that, we split the data into training and test sets.
# split data into training and test set
cols_features = df_data.columns.tolist()
cols_features.remove('C')
X = df_data[cols_features].values
y = df_data[['C']].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=21, stratify=y)
df_train = pd.DataFrame(np.concatenate((y_train, X_train), axis=1), columns=df_data.columns)
Predicting with Naive Bayes
Now, we are ready to train our Naive Bayes model using the training data and then measure classification performance with the test set. I'll use the classifier from sklearn, but pgmpy has a Naive Bayes classifier as well should you prefer it.
# train naive bayes classifier and predict
model_nb = MultinomialNB().fit(X_train, y_train)
y_pred = model_nb.predict(X_test)
print(classification_report(y_test, y_pred))
In short, Naive Bayes precision and recall are both 72%.
Predicting with TAN
Next up is TAN. First, let's learn the graph structure from the training data. To capture the interaction between the feature variables, TAN casts a tree structure over them. So, you need to pass in the "root_node" parameter. To be fair, we pretend that we don't know what the real graph looks like, so we pick a random node as the root node:
# learn the TAN graph structure from data
est = TreeSearch(df_train, root_node='U')
dag = est.estimate(estimator_type='tan', class_node='C')
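The returned dag behaves like an ordinary directed graph, so we can inspect which feature-to-feature links the tree search picked up. Since we chose 'U' as the root, the learned tree over the features is rooted there and won't necessarily mirror the true graph. A minimal sketch; the optional plot assumes networkx and matplotlib are installed:
# inspect the learned TAN structure
print(dag.edges())

# optional: visualize the learned graph
import networkx as nx
import matplotlib.pyplot as plt
nx.draw_networkx(dag, arrows=True)
plt.show()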
Before we can make predictions, we also need to parameterize the model using the training data.
# construct Bayesian network by parameterizing the graph structure
model = BayesianModel(dag.edges())
model.fit(df_train, estimator=BayesianEstimator, prior_type='K2')
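If you're curious about the learned parameters, the fitted model exposes a CPD per node, for example:
# inspect a learned conditional probability table, e.g. for the class node
print(model.get_cpds('C'))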
Now we're ready to classify the test data using the TAN model and measure the performance:
# draw inference from BN
X_test_df = pd.DataFrame(X_test, columns=cols_features)
y_pred = model.predict(X_test_df).values
print(classification_report(y_test, y_pred))
Our TAN model's precision and recall are both 81%, compared to 72% for Naive Bayes. That's quite a significant improvement!
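To see where the improvement comes from, you can also compare the two confusion matrices. This is a small sketch using standard scikit-learn and the model_nb and X_test_df objects created above:
from sklearn.metrics import confusion_matrix

# rows are true classes, columns are predicted classes
print(confusion_matrix(y_test, model_nb.predict(X_test)))         # Naive Bayes
print(confusion_matrix(y_test, model.predict(X_test_df).values))  # TAN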
Predicting with Incomplete Data
As a bonus, let's now try to make predictions on incomplete data. We will ask pgmpy to predict the class when only features S, T, and U are given.
from pgmpy.inference import VariableElimination

# predict the class given some of the features
infer = VariableElimination(model)
query = infer.query(variables=['C'], evidence={'S': 0, 'T': 1, 'U': 2})
print(query)
In this case, pgmpy reports a 93% chance that C=1.
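If you only care about the most likely class rather than the full distribution, VariableElimination also offers map_query, which returns the maximum a posteriori assignment:
# most likely value of C given the same partial evidence
print(infer.map_query(variables=['C'], evidence={'S': 0, 'T': 1, 'U': 2}))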
Instead of predicting the class, let's predict what feature R is likely to be given class C=1 and feature S=1:
# predict the feature given the class and another feature
query = infer.query(variables=['R'], evidence={'C': 1, 'S': 1})
print(query)
In this case, pgmpy reports a 27% chance of R=0, 32% of R=1, and 41% of R=2.
Pretty cool, isn't it? As I said at the beginning of the article, pgmpy is versatile, and I'm only scratching the surface of what it can do.