K-means アルゴリズムは、典型的なクラスタリングアルゴリズムです。

これは、空間内の k 個の頂点を中心として使用し、それらに最も近い頂点をグループ化することによってクラスタリングを実行します。 クラスタリングの中心値は、最適なクラスタリング結果が得られるまで、反復しながら連続的に更新されます。

サンプルの集合を k個のクラスに分割するため、アルゴリズムは以下のように動作します。
  1. k 個のクラスの中心の初期値を選択します。
  2. 任意のサンプルから k 個の中心までの距離を反復 i で 計算し、サンプルを最も近い中心のクラスにグループ化します。
  3. 平均や他の方法を使ってクラスの中心値を更新します。
  4. すべての k 個のクラスター中心について、反復後に更新された値が変化しないままであるか、しきい値よりも小さい場合、反復は終了します。 そうでない場合、反復は継続します。

サンプルコード

K-means クラスタリングアルゴリズムのコードは以下のとおりです。
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.logging.log4j.Logger;

import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
Import com. aliyun. ODPS. graph. computercontext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;

public class Kmeans {
  private final static Logger LOG = Logger.getLogger(Kmeans.class);

  public static class KmeansVertex extends
  Vertex<Text, Tuple, NullWritable, NullWritable> {

    @ Override
    public void compute(
    ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
    Iterable<NullWritable> messages) throws IOException {
      context.aggregate(getValue());
    }

  

  public static class KmeansVertexReader extends
  GraphLoader<Text, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
    MutationContext<Text, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      KmeansVertex vertex = new KmeansVertex();
      vertex.setId(new Text(String.valueOf(recordNum.get())));
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    

  

  public static class KmeansAggrValue implements Writable {

    Tuple centers = new Tuple();
    Tuple sums = new Tuple();
    Tuple counts = new Tuple();

    @ Override
    public void write(DataOutput out) throws IOException {
      centers.write(out);
      sums.write(out);
      counts.write(out);
    

    @Override
    public void readFields(DataInput in) throws IOException {
      centers = new Tuple();
      centers.readFields(in);
      sums = new Tuple();
      sums.readFields(in);
      counts = new Tuple();
      counts.readFields(in);
    

    @Override
    public String toString(){
      return "centers " + centers.toString() + ", sums " + sums.toString()
          + ", counts " + counts.toString();
    

  

public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {

    @SuppressWarnings("rawtypes")
    @Override
    public KmeansAggrValue createInitialValue(WorkerContext context)
        throws IOException {
      KmeansAggrValue aggrVal = null;
      if (context.getSuperstep() == 0) {
        aggrVal = new KmeansAggrValue();
        aggrVal.centers = new Tuple();
        aggrVal.sums = new Tuple();
        aggrVal.counts = new Tuple();

        byte[] centers = context.readCacheFile("centers");
        String lines[] = new String(centers).split("\n");

for (int i = 0; i < lines.length; i++) {
          String[] ss = lines[i].split(",");
          Tuple center = new Tuple();
          Tuple sum = new Tuple();
          for (int j = 0; j < ss.length; ++j) {
            center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
            sum.append(new DoubleWritable(0.0));
          
          LongWritable count = new LongWritable(0);
          aggrVal.sums.append(sum);
          aggrVal.counts.append(count);
          aggrVal.centers.append(center);
        
      } else{
        aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
      

      return aggrVal;
    

    @Override
    Public void aggregate (glasvalue, object item ){
      int min = 0;
      double mindist = Double.MAX_VALUE;
      Tuple point = (Tuple) item;

for (int i = 0; i < value.centers.size(); i++) {
        Tuple center = (Tuple) value.centers.get(i);
        // use Euclidean Distance, no need to calculate sqrt
        double dist = 0.0d;
        for (int j = 0; j < center.size(); j++) {
          double v = ((DoubleWritable) point.get(j)).get()
              - ((DoubleWritable) center.get(j)).get();
          dist += v * v;
        
        if (dist < mindist) {
          mindist = dist;
          min = i;
        
      

      // update sum and count
      Tuple sum = (Tuple) value.sums.get(min);
      for (int i = 0; i < point.size(); i++) {
        DoubleWritable s = (DoubleWritable) sum.get(i);
        s.set(s.get() + ((DoubleWritable) point.get(i)).get());
      
      LongWritable count = (LongWritable) value.counts.get(min);
      count.set(count.get() + 1);
    

    @Override
    public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
    for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple that = (Tuple) partial.sums.get(i);
        for (int j = 0; j < sum.size(); j++) {
          DoubleWritable s = (DoubleWritable) sum.get(j);
          s.set(s.get() + ((DoubleWritable) that.get(j)).get());
        
      

for (int i = 0; i < value.counts.size(); i++) {
        LongWritable count = (LongWritable) value.counts.get(i);
        count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
      
    

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, KmeansAggrValue value)
        throws IOException {

      // compute new centers
      Tuple newCenters = new Tuple(value.sums.size());
      for (int i = 0; i < value.sums.size(); i++) {
        Tuple sum = (Tuple) value.sums.get(i);
        Tuple newCenter = new Tuple(sum.size());
        LongWritable c = (LongWritable) value.counts.get(i);
        for (int j = 0; j < sum.size(); j++) {

          DoubleWritable s = (DoubleWritable) sum.get(j);
          double val = s.get() / c.get();
          newCenter.set(j, new DoubleWritable(val));

          // reset sum for next iteration
          s.set(0.0d);
        
        // reset count for next iteration
        c.set(0);
        newCenters.set(i, newCenter);
      

      // update centers
      Tuple oldCenters = value.centers;
      value.centers = newCenters;

      LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);

      // compare new/old centers
      boolean converged = true;
      for (int i = 0; i < value.centers.size() && converged; i++) {
        Tuple oldCenter = (Tuple) oldCenters.get(i);
        Tuple newCenter = (Tuple) newCenters.get(i);
        double sum = 0.0d;
        for (int j = 0; j < newCenter.size(); j++) {
          double v = ((DoubleWritable) newCenter.get(j)).get()
              - ((DoubleWritable) oldCenter.get(j)).get();
          sum += v * v;
        
        double dist = Math.sqrt(sum);
        LOG.info("old center: " + oldCenter + ", new center: " + newCenter
            + ", dist: " + dist);
        // converge threshold for each center: 0.05
        converged = dist < 0.05d;
      

      if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
        // converged or reach max iteration, output centers
        for (int i = 0; i < value.centers.size(); i++) {
          context.write(((Tuple) value.centers.get(i)).toArray());
        
        // true means to terminate iteration
        return true;
      

      // false means to continue iteration
      return false;
    
  

  private static void printUsage() {
  System. out. println ("Usage: <in> <out> [Max iterations (default 30)] ");
    System.exit(-1);
  

  public static void main(final String [] args)throws IOException{
  if (args.length < 2)
      printUsage();

    GraphJob job = new GraphJob();

    job.setGraphLoaderClass(KmeansVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setVertexClass(KmeansVertex.class);
    job.setAggregatorClass(KmeansAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());

    // default max iteration is 30
    job.setMaxIteration(30);
    if (args.length >= 3)
      job.setMaxIteration(Integer.parseInt(args[2]));

    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  

以下は、K-means のソースコードについての説明です。
  • 26 行目: KmeansVertex を定義します。 compute () メソッドの実装はシンプルです。 コンテキストオブジェクトの aggregate () メソッドを呼び出します。 次に、現在の頂点の値 (タプル 型で、ベクトルで表現される) を送信します。
  • 38 行目: KmeansVertexReader クラスを定義し、グラフを読み込み、テーブル内の各レコードを頂点と見なします。 頂点 ID は関係ありません。送信された recordNum は ID として使用されます。 頂点の値は、レコードのすべての列から構成されるタプルです。
  • 83 行目: KmeansAggregator を定義します。 このクラスは、K-means アルゴリズムの主なロジックをカプセル化します。ここで、
    • createInitialValue は、反復ごとに初期値を作成します (k クラス 中心点)。 最初の反復 (superstep が 0) では、値は中心点の初期値です。 それ以外の場合は、値は最後の反復が終了したときの新しい中心点です。
    • aggregate () メソッドは、各頂点から異なるクラスの中心までの距離を計算し、最も近い中心のクラスとしてその頂点を分類し、そのクラスの合計とカウントを更新します。
    • merge () メソッドは、各 Worker によって収集された合計とカウントを組み合わせます。
    • terminate () メソッドは、各クラスの合計とカウントに基づいて新しい中心点を計算します。 新しい中心点と古い中心点の間の距離がしきい値より小さいか、または反復回数が上限値に達すると、反復は終了します (false が返されます)。 最終的な中心点が結果テーブルに書き込まれます。
  • 236 行目: メインプログラム (main 関数) を実行し、GraphJob を定義し、そして Vertex/GraphLoader/Aggregator の実装を指定します。 最大反復数 (デフォルトは 30)、および入力テーブルと出力テーブル。
  • 243 行目: job.setRuntimePartitioning (false) を指定します。 K-means アルゴリズムでは、グラフの読み込み中に頂点を分散させる必要はありません。 RuntimePartitioning が false に設定されると、グラフ読み込みのパフォーマンスが向上します。