Écrire le programme d'évaluation dans un tJava - 6.5

Machine Learning

author
Talend Documentation Team
EnrichVersion
6.5
EnrichProdName
Talend Big Data
Talend Big Data Platform
Talend Data Fabric
Talend Real-Time Big Data Platform
task
Création et développement > Systèmes tiers > Composants Machine Learning
Gouvernance de données > Systèmes tiers > Composants Machine Learning
Qualité et préparation de données > Systèmes tiers > Composants Machine Learning
EnrichPlatform
Studio Talend

Procédure

  1. Double-cliquez sur le tJava pour ouvrir sa vue Component.
  2. Cliquez sur le bouton Sync columns pour vous assurer que le tJava récupère le schéma répliqué du tClassify.
  3. Cliquez sur l'onglet Advanced settings pour ouvrir sa vue.
  4. Dans le champ Classes, saisissez le code pour définir les classes Java à utiliser afin de vérifier si les libellés de classe prédits correspondent aux libellés actuels des classes (spam pour les messages indésirables et ham pour les messages normaux). Dans ce scénario, row7 est l'ID de la connexion entre le tClassify et le tReplicate et contient les résultats de classification à envoyer aux composants suivants. row7Struct est la classe Java du RDD pour les résultats de classification. Dans votre code, vous devez remplacer row7, utilisé seul ou au sein de row7Struct, par l'ID de la connexion utilisée dans votre Job.
    Les noms des colonnes, comme reallabel ou label ont été définis dans l'étape précédente lors de la configuration des différents composants. Si vous les avez nommées différemment, gardez la cohérence pour utilisation dans votre code.
    public static class SpamFilterFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return row7.reallabel.equals("spam");
    	}
    	
    }
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    public static class TrueNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return (row7.label.equals("ham") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class TruePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// true positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("spam"));
    	}
    	
    }
    
    public static class FalseNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class FalsePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("ham") && row7.reallabel.equals("spam"));
    	}
    	
    }
  5. Cliquez sur l'onglet Basic settings pour ouvrir sa vue et, dans le champ Code, saisissez le code à utiliser pour calculer le score de précision et le coefficient de corrélation de Matthews (Matthews Correlation Coefficient, MCC) du modèle de classification.
    Pour une explication générale relative à ce coefficient, consultez l'article Wikipédia https://en.wikipedia.org/wiki/Matthews_correlation_coefficient (en anglais).
    long nbTotal = rdd_tJava_1.count();
    
    long nbSpam = rdd_tJava_1.filter(new SpamFilterFunction()).count();
    
    long nbHam = nbTotal - nbSpam;
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    long tn = rdd_tJava_1.filter(new TrueNegativeFunction()).count();
    
    long tp = rdd_tJava_1.filter(new TruePositiveFunction()).count();
    
    long fn = rdd_tJava_1.filter(new FalseNegativeFunction()).count();
    
    long fp = rdd_tJava_1.filter(new FalsePositiveFunction()).count();
    
    double mmc = (double)(tp*tn -fp*fn) / java.lang.Math.sqrt((double)((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)));
    
    System.out.println("Accuracy:"+((double)(tp+tn)/(double)nbTotal));
    System.out.println("Spams caught (SC):"+((double)tp/(double)nbSpam));
    System.out.println("Blocked hams (BH):"+((double)fp/(double)nbHam));
    System.out.println("Matthews correlation coefficient (MCC):" + mmc);