0f2597ccdb6f205ebae382d7950031da4b8e42ee
[simdecs2.git] / src / java / org / ufcspa / simdecs / util / UnBBayesUtil.java
1 /*
2  * To change this template, choose Tools | Templates
3  * and open the template in the editor.
4  */
5 package org.ufcspa.simdecs.util;
6
7 import java.io.File;
8 import java.security.InvalidParameterException;
9 import java.util.*;
10 import javax.persistence.EntityManager;
11 import org.ufcspa.simdecs.entities.Nodo;
12 import org.ufcspa.simdecs.entities.NodoPaciente;
13 import org.ufcspa.simdecs.entities.Paciente;
14 import org.ufcspa.simdecs.entities.Rede;
15 import org.ufcspa.simdecs.exceptions.InvalidNodeState;
16 import unbbayes.io.xmlbif.version6.XMLBIFIO;
17 import unbbayes.prs.Node;
18 import unbbayes.prs.bn.JunctionTreeAlgorithm;
19 import unbbayes.prs.bn.ProbabilisticNetwork;
20 import unbbayes.prs.bn.ProbabilisticNode;
21 import unbbayes.prs.bn.TreeVariable;
22
23 /**
24  *
25  * @author maroni
26  */
27 public class UnBBayesUtil {
28
29     
30     private ProbabilisticNetwork rede;
31
32     public UnBBayesUtil(Rede redeEntity) throws Exception {
33         abreRede(redeEntity);
34     }
35         
36     public final void abreRede(Rede redeEntity) throws Exception {
37         // required to run unbbayes gui classes on server
38         System.setProperty("java.awt.headless", "false"); 
39         rede = new ProbabilisticNetwork("rede"+redeEntity.getId());
40         
41         XMLBIFIO.loadXML(new File(redeEntity.getArquivo()), rede);
42         
43         compilar();
44     }
45
46     public void compilar() {
47         JunctionTreeAlgorithm jt = new JunctionTreeAlgorithm();
48         jt.setNet(rede);            
49         jt.run();
50     }
51     
52     
53     public Node getNodeByName(String nodeName) {
54         for(Node node : rede.getNodes()) {
55             if (node.getName().toLowerCase().equals(nodeName.toLowerCase()))
56                 return node;
57         }
58         
59         return null;
60     }
61     
62     
63     public void setNodeState(ProbabilisticNode node, Object ... stateProbs) throws Exception {
64         HashMap<String, Float> states = new HashMap<String, Float>();
65         String state = null;
66         Float  probability = null;
67         
68         for (Object stateProb : stateProbs) {
69             if (state == null)
70                 if (stateProb instanceof String)
71                     state = (String) stateProb;
72                 else
73                     throw new InvalidParameterException("Os parametros devem ser informados no formato: state(string), prob(float), ..., ...");
74             else if (state != null)
75                 if ( !(stateProb instanceof String) ) {
76                     if (stateProb instanceof Integer)
77                         probability =  new Float ((Integer) stateProb);
78                     else if (stateProb instanceof Float)
79                         probability =  (Float) stateProb;
80                     else
81                         throw new InvalidParameterException("Os parametros devem ser informados no formato: state(string), prob(float), ..., ...");
82                         
83                     states.put(state.toLowerCase(), probability);
84                     probability = null;
85                     state = null;
86                 } else
87                     throw new InvalidParameterException("Os parametros devem ser informados no formato: state(string), prob(float), ..., ...");
88         }
89                 
90         float likelihood[] = new float[node.getStatesSize()]; //nr. de estados do nodo
91         for(int i=0; i < node.getStatesSize(); i++) {
92             Float prob = states.get(node.getStateAt(i).toLowerCase());
93             states.remove(node.getStateAt(i).toLowerCase());
94             if (prob != null)
95                 likelihood[i] = prob;
96             else
97                 likelihood[i] = 0;                
98         }
99
100         if (states.size() > 0) {
101             StringBuilder invalidStates = new StringBuilder();
102             for (String key : states.keySet()) {
103                 invalidStates.append(key);
104                 invalidStates.append(" ");
105             }
106             throw new InvalidNodeState("Estados de nodos inválidos: " + invalidStates);
107         }
108
109         node.addLikeliHood(likelihood);
110         rede.updateEvidences();            
111     }
112     
113     public List<ProbabilisticBean> getProbabilidades(TreeVariable node) {
114                 
115         List<ProbabilisticBean> listProbs = new ArrayList<ProbabilisticBean>();
116         
117         for (int i=0; i < node.getStatesSize(); i++)
118             listProbs.add(new ProbabilisticBean(node.getStateAt(i), node.getMarginalAt(i)));
119
120         return listProbs;
121     }
122     
123     public float getProbabilidadeByState(TreeVariable node, String state) {
124         for (ProbabilisticBean bean : getProbabilidades(node)) {
125             if (state.toLowerCase().equals(bean.getState().toLowerCase()))
126                 return bean.getProbability();
127         }
128         
129         return 0;
130     }
131
132     public List<ProbabilisticBean> getProbabilidadesAscOrder(TreeVariable node) {
133         List<ProbabilisticBean> listProbs = getProbabilidades(node);
134         
135         Collections.sort(listProbs, new Comparator() {
136             @Override
137             public int compare(Object o1, Object o2) {
138                 ProbabilisticBean nodo1 = (ProbabilisticBean) o1;
139                 ProbabilisticBean nodo2 = (ProbabilisticBean) o2;
140                 
141                 if (nodo1.getProbability() > nodo2.getProbability())
142                     return 1;
143                 else if (nodo1.getProbability() < nodo2.getProbability())
144                     return -1;
145                 else
146                     return 0;
147             }
148         });        
149         
150         return listProbs;
151     }
152     
153     public List<ProbabilisticBean> getProbabilidadesDescOrder(TreeVariable node) {
154         List<ProbabilisticBean> listProbs = getProbabilidades(node);
155         
156         Collections.sort(listProbs, new Comparator() {
157             @Override
158             public int compare(Object o1, Object o2) {
159                 ProbabilisticBean nodo1 = (ProbabilisticBean) o1;
160                 ProbabilisticBean nodo2 = (ProbabilisticBean) o2;
161                 
162                 if (nodo1.getProbability() > nodo2.getProbability())
163                     return -1;
164                 else if (nodo1.getProbability() < nodo2.getProbability())
165                     return 1;
166                 else
167                     return 0;
168             }
169         });        
170         
171         return listProbs;
172     }
173     
174     public static UnBBayesUtil getRedePaciente(EntityManager em, Paciente paciente) throws Exception {
175         
176         UnBBayesUtil unBUtil = new UnBBayesUtil(paciente.getRede());
177
178         List<NodoPaciente> nodosPaciente = em.createQuery("From NodoPaciente np Where np.paciente.id=:pPaciente")
179                                              .setParameter("pPaciente", paciente.getId())
180                                              .getResultList();
181         
182         for(NodoPaciente nodo : nodosPaciente) {
183             ProbabilisticNode pn = (ProbabilisticNode) unBUtil.getNodeByName(nodo.getNodo().getNome());
184             if (pn != null)
185                 unBUtil.setNodeState(pn, "Yes", 1, "Not", 0);
186         }
187         
188         return unBUtil;
189     }
190     
191     public List<NodoBean> getProbabilidades(EntityManager em, String tipoNodo) {
192         Iterator<Nodo> itNodos = em.createNamedQuery("Nodo.getByTipo")
193                                    .setParameter("pTipo", tipoNodo)
194                                    .getResultList().iterator();
195         
196         ArrayList nodos = new ArrayList<NodoBean>();
197
198         while(itNodos.hasNext()) {
199             Nodo nodo = itNodos.next();
200             NodoBean nodoBean = new NodoBean(nodo, (Float) getProbabilidadeByState((TreeVariable) getNodeByName(nodo.getNome()), "Yes"));
201             nodos.add(nodoBean);
202         }
203         
204         Collections.sort(nodos, new Comparator() {
205             @Override
206             public int compare(Object o1, Object o2) {
207                 NodoBean nodo1 = (NodoBean) o1;
208                 NodoBean nodo2 = (NodoBean) o2;
209                 
210                 if (nodo1.getProbabilidade() > nodo2.getProbabilidade())
211                     return -1;
212                 else if (nodo1.getProbabilidade() < nodo2.getProbabilidade())
213                     return 1;
214                 else
215                     return 0;
216             }
217         });
218         
219         return nodos;
220     }
221
222     public Nodo getPrimeiroDiagnostico(EntityManager em) {
223         List<NodoBean> diagnosticos = getProbabilidades(em, Nodo.DIAGNOSTICO);
224         for(NodoBean nodoBean : diagnosticos) {
225             return nodoBean.getNodo();
226         }
227         
228         return null;
229     }
230
231     public Nodo getSegundoDiagnostico(EntityManager em) {
232         List<NodoBean> diagnosticos = getProbabilidades(em, Nodo.DIAGNOSTICO);
233         
234         int d=1;
235         float probabilidadeDiag1=0;
236         float probabilidadeDiag2=0;
237         Nodo segundoDiagnostico=null;
238         for(NodoBean nodoBean : diagnosticos) {
239             if (d == 1)
240                 probabilidadeDiag1 = nodoBean.getPercentualProbabilidade();            
241             if (d == 2) {
242                 probabilidadeDiag2 = nodoBean.getPercentualProbabilidade();
243                 segundoDiagnostico = nodoBean.getNodo();
244             }
245
246             if (d++ == 2)
247                 break;
248         }
249         
250         float diferenca = probabilidadeDiag1 - probabilidadeDiag2;
251         if (diferenca <= 25)
252             return segundoDiagnostico;
253         else
254             return null;
255     }
256     
257     public List<Nodo> getCondutas(EntityManager em) {
258         List<NodoBean> condutas = getProbabilidades(em, Nodo.CONDUTA);
259         List<Nodo> condutasRetornar = new ArrayList<Nodo>();
260         for(NodoBean conduta : condutas) {
261             if (conduta.getPercentualProbabilidade() < 50)
262                 continue;
263             
264             condutasRetornar.add(conduta.getNodo());
265 //System.out.println("CONDUTA*: " + conduta.getNodo().getNomeAmigavel());
266         }
267         
268         return condutasRetornar;
269         
270     }
271 }