tJavaで評価プログラムを作成する - Cloud - 8.0

Machine Learning

Version
Cloud
8.0
Language
日本語 (日本)
Product
Talend Big Data
Talend Big Data Platform
Talend Data Fabric
Talend Real-Time Big Data Platform
Module
Talend Studio
Content
ジョブデザインと開発 > サードパーティーシステム > 機械学習コンポーネント
データガバナンス > サードパーティーシステム > 機械学習コンポーネント
データクオリティとプレパレーション > サードパーティーシステム > 機械学習コンポーネント

手順

  1. tJavaをダブルクリックして、その[Component] (コンポーネント)ビューを開きます。
  2. tJavatClassifyの複製されたスキーマを確実に取得するために、[Sync columns] (カラムを同期)をクリックします。
  3. [Advanced settings] (詳細設定)タブをクリックして、ビューを開きます。
  4. [Classes] (クラス)フィールドにコードを入力して、予測されたクラスラベルが実際のクラスラベルと一致するかどうかを確認するために使うJavaクラスを定義します(ジャンクメッセージにはspam、通常のメッセージにはham)。このシナリオでは、row7tClassifytReplicateの間の接続のIDであり、後続のコンポーネントに送信される分類結果を保持します。また、row7Structは分類結果のRDDのJavaクラスです。コードに含まれているrow7は、単独で使うかrow7Struct内で使うかに関係なく、ジョブに使われている対応する接続IDに置き換える必要があります。
    reallabellabelなどのカラム名は、さまざまなコンポーネントを設定した前のステップで定義済したものです。これらに異なる名前を付けた場合は、コードで使うために一貫性のある状態に保つ必要があります。
    public static class SpamFilterFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return row7.reallabel.equals("spam");
    	}
    	
    }
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    public static class TrueNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		
    		return (row7.label.equals("ham") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class TruePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// true positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("spam"));
    	}
    	
    }
    
    public static class FalseNegativeFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("spam") && row7.reallabel.equals("ham"));
    	}
    	
    }
    
    public static class FalsePositiveFunction implements 
    	org.apache.spark.api.java.function.Function<row7Struct, Boolean>{
    	private static final long serialVersionUID = 1L;
    	@Override
    	public Boolean call(row7Struct row7) throws Exception {
    		// false positive cases
    		return (row7.label.equals("ham") && row7.reallabel.equals("spam"));
    	}
    	
    }
  5. [Basic settings] (基本設定)タブをクリックしてビューを開き、[Code] (コード)フィールドに、分類モデルの精度スコアとMatthewsコリレーション係数(MCC)の計算に使うコードを入力します。
    Mathewsコリレーション係数に関する一般的な説明は、Wikipediaのhttps://en.wikipedia.org/wiki/Matthews_correlation_coefficientを参照してください。
    long nbTotal = rdd_tJava_1.count();
    
    long nbSpam = rdd_tJava_1.filter(new SpamFilterFunction()).count();
    
    long nbHam = nbTotal - nbSpam;
    
    // 'negative': ham
    // 'positive': spam
    // 'false' means the real label & predicted label are different 
    // 'true' means the real label & predicted label are the same
    
    long tn = rdd_tJava_1.filter(new TrueNegativeFunction()).count();
    
    long tp = rdd_tJava_1.filter(new TruePositiveFunction()).count();
    
    long fn = rdd_tJava_1.filter(new FalseNegativeFunction()).count();
    
    long fp = rdd_tJava_1.filter(new FalsePositiveFunction()).count();
    
    double mmc = (double)(tp*tn -fp*fn) / java.lang.Math.sqrt((double)((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)));
    
    System.out.println("Accuracy:"+((double)(tp+tn)/(double)nbTotal));
    System.out.println("Spams caught (SC):"+((double)tp/(double)nbSpam));
    System.out.println("Blocked hams (BH):"+((double)fp/(double)nbHam));
    System.out.println("Matthews correlation coefficient (MCC):" + mmc);