Added probability table examples.
[simdecs.git] / src / java / org / ufcspa / simdecs / diagram / bn / BayesianNetwork.java
1 /*
2  * To change this template, choose Tools | Templates
3  * and open the template in the editor.
4  */
5 package org.ufcspa.simdecs.diagram.bn;
6
7 import java.io.File;
8 import java.util.ArrayList;
9 import java.util.Vector;
10
11 import unbbayes.io.xmlbif.version6.*;
12 import unbbayes.prs.Node;
13 import unbbayes.prs.bn.JunctionTreeAlgorithm;
14 import unbbayes.prs.bn.ProbabilisticNetwork;
15 import unbbayes.prs.bn.ProbabilisticNode;
16
17 import unbbayes.simdecs.PerguntaNodo;
18
19 /**
20  *
21  * @author mchelem
22  */
23 public class BayesianNetwork {
24             
25     public static ProbabilisticNetwork loadNetwork(String filename) throws Exception {
26         ProbabilisticNetwork bayesianNetwork = new ProbabilisticNetwork(null);
27         XMLBIFIO.loadXML(new File(filename),bayesianNetwork);        
28         return bayesianNetwork;
29     } 
30     
31     /* Testing */
32     public static void main(String[] args) throws Exception {
33         ProbabilisticNetwork bn = BayesianNetwork.loadNetwork("samples/headache.xml");
34         
35         // Get bayesian network name
36         System.out.println("Network name: "+ bn.getName());
37         
38         // Get all the nodes
39         System.out.println("All nodes: ");
40         ArrayList<Node> nodes = bn.getNodes();
41         for (Node node: nodes){
42             System.out.println("-> " + node.getName());            
43         }
44         
45         // Get node by name and its children
46         // To work with probabilistic table, node must be a Probabilistic node
47         ProbabilisticNode facialPainNode = (ProbabilisticNode)bn.getNode("facial_pain");
48         System.out.println("\nNode: " + facialPainNode.getName());
49         
50         for (int i = 0; i < facialPainNode.getStatesSize(); i++) {
51              System.out.println(facialPainNode.getStateAt(i)+":"+facialPainNode.getMarginalAt(i));
52         }
53         
54         // Fields custo, tempo and pergunta moved to database. 
55         // Included here for testing.
56         System.out.println("Custo: " + facialPainNode.getCustoEtapa());
57         System.out.println("Tempo: " + facialPainNode.getTempoEtapa());
58            
59         Vector<PerguntaNodo> perguntas = facialPainNode.getPerguntas();
60         for (PerguntaNodo pergunta: perguntas) {
61             // Setting questions because it is not set in the saved net
62             pergunta.setPergunta("Minha pergunta?");
63             System.out.println("Pergunta: " + pergunta.getPergunta());  
64          }
65                 
66         ArrayList<Node> childrenNodes = facialPainNode.getChildren();
67         System.out.println("Children:");
68         for (Node node: childrenNodes){
69             System.out.println("-> " + node.getName());            
70         }
71         
72         // Updating nodes
73         // Before updating probabilities, it is required to run the network
74         System.out.println("Running Bayesian network:");
75
76         JunctionTreeAlgorithm jt = new JunctionTreeAlgorithm();
77         jt.setNet(bn);      
78         jt.run();
79         
80         ProbabilisticNode holocranialPainNode = (ProbabilisticNode) bn.getNode("holocranial_pain");
81         System.out.println("\nNode: " + holocranialPainNode.getName());
82
83         System.out.println("Before updating probability table:");
84         for (int i = 0; i < holocranialPainNode.getStatesSize(); i++) {
85              System.out.println(holocranialPainNode.getStateAt(i)+":"+holocranialPainNode.getMarginalAt(i));
86         }
87                 
88         float likelihood[] = new float[holocranialPainNode.getStatesSize()]; 
89         likelihood[0] = 1.0f;
90         likelihood[1] = 0.0f;
91         holocranialPainNode.addLikeliHood(likelihood);
92         
93         bn.updateEvidences();
94         
95         System.out.println("After updating probability table:");        
96         for (int i = 0; i < holocranialPainNode.getStatesSize(); i++) {
97              System.out.println(holocranialPainNode.getStateAt(i)+":"+holocranialPainNode.getMarginalAt(i));
98         }
99     }
100     
101     
102 }
103