Spark 实现kmeans算法 Posted on 2018-11-11 | In 分布式模型与编程 | 阅读数 次 spark 实现K-means算法123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156package kmeans;import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.FileNotFoundException;import java.io.IOException;import java.io.InputStreamReader;import java.util.ArrayList;.import java.util.Arrays;import java.util.Iterator;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaPairRDD;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.FlatMapFunction;import org.apache.spark.api.java.function.Function2;import org.apache.spark.api.java.function.Function;import org.apache.spark.api.java.function.PairFunction;import org.apache.spark.api.java.function.VoidFunction;import scala.Tuple2;public class kmeans{ static double[][] center = new double[4][2]; //这里有4个中心点,为2维 static int[] number = new int[4]; //记录属于当前中心点的数据的个数,方便做除法 static double[][] new_center = new double[4][2]; //计算出来的新中心点 public static void main(String[] args) { // 从文件中读出中心点,并且放入center数组中 ArrayList<String> arrayList = new ArrayList<String>(); try { File file = new File("/usr/local/hadoop-2.7.3/centers.txt"); InputStreamReader input = new InputStreamReader(new FileInputStream(file)); BufferedReader bf = new BufferedReader(input); // 按行读取字符串 String str; while ((str = bf.readLine()) != null) { arrayList.add(str); } bf.close(); input.close(); } catch (IOException e) { e.printStackTrace(); } // 对ArrayList中存储的字符串进行处理 for (int i = 0; i < 4; i++) { for (int j = 0; j < 2; j++) { String s = arrayList.get(i).split(",")[j]; center[i][j] = Double.parseDouble(s); } } //System.out.println("center+++" + center[3][1]); SparkConf conf = new SparkConf().setAppName("kmeans").setMaster("local[*]"); JavaSparkContext jsc = new JavaSparkContext(conf); JavaRDD<String> datas = jsc.textFile("spark/input4/k-means.dat"); //从hdfs上读取data while(true) { for (int i = 0; i< 4;i++) //注意每次循环都需要将number[i]变为0 { number[i]=0; } //将data分开,得到key: 属于某个中心点的序号(0/1/2/3),value: 与该中心点的距离 JavaPairRDD<Integer, Tuple2<Double, Double>> data = datas.mapToPair(new PairFunction<String, Integer,Tuple2<Double, Double>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Integer,Tuple2<Double, Double>> call(String str) throws Exception { final double[][] loc = center; String[] datasplit = str.split(","); double x = Double.parseDouble(datasplit[0]); double y = Double.parseDouble(datasplit[1]); double minDistance = 99999999; int centerIndex = 0; for(int i = 0;i < 4;i++){ double itsDistance = (x-loc[i][0])*(x-loc[i][0])+(y-loc[i][1])*(y-loc[i][1]); if(itsDistance < minDistance){ minDistance = itsDistance; centerIndex = i; } } number[centerIndex]++; //得到属于4个中心点的个数 return new Tuple2<Integer,Tuple2<Double, Double>>(centerIndex, new Tuple2<Double,Double>(x,y)); // the center's number & data } }); //得到key: 属于某个中心点的序号, value:新中心点的坐标 JavaPairRDD<Integer, Iterable<Tuple2<Double, Double>>> sum_center = data.groupByKey(); //System.out.println(sum_center.collect()); JavaPairRDD<Integer,Tuple2<Double, Double>> Ncenter = sum_center.mapToPair(new PairFunction<Tuple2<Integer, Iterable<Tuple2<Double, Double>>>,Integer,Tuple2<Double, Double>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Integer, Tuple2<Double, Double>> call(Tuple2<Integer, Iterable<Tuple2<Double, Double>>> a)throws Exception { //System.out.println("i am here**********new center******"); int sum_x = 0; int sum_y = 0; Iterable<Tuple2<Double, Double>> it = a._2; for(Tuple2<Double, Double> i : it) { sum_x += i._1; sum_y +=i._2; } double average_x = sum_x / number[a._1]; double average_y = sum_y/number[a._1]; //System.out.println("**********new center******"+a._1+" "+average_x+","+average_y); return new Tuple2<Integer,Tuple2<Double,Double>>(a._1,new Tuple2<Double,Double>(average_x,average_y)); } }); //将中心点输出 Ncenter.foreach(new VoidFunction<Tuple2<Integer,Tuple2<Double,Double>>>() { private static final long serialVersionUID = 1L; @Override public void call(Tuple2<Integer,Tuple2<Double,Double>> t) throws Exception { new_center[t._1][0] = t._2()._1; new_center[t._1][1] = t._2()._2; System.out.println("the new center: "+ t._1+" "+t._2()._1+" , "+t._2()._2); } }); //判断新的中心点和原来的中心点是否一样,一样的话退出循环得到结果,不一样的话继续循环(这里可以设置一个迭代次数) double distance = 0; for(int i=0;i<4;i++) { distance += (center[i][0]-new_center[i][0])*(center[i][0]-new_center[i][0]) + (center[i][1]-new_center[i][1])*(center[i][1]-new_center[i][1]); } if(distance == 0.0) { //finished for(int j = 0;j<4;j++) { System.out.println("the final center: "+" "+center[j][0]+" , "+center[j][1]); } break; } else { for(int i = 0;i<4;i++) { center[i][0] = new_center[i][0]; center[i][1] = new_center[i][1]; new_center[i][0] = 0; new_center[i][1] = 0; System.out.println("the new center: "+" "+center[i][0]+" , "+center[i][1]); } } } } } 输入:1. centers.txt : 96,826 606,776 474,866 400,768 data.dat: 存放所有点的坐标存放所有点的坐标。 Post author: luyiqu Post link: https://luyiqu.github.io/2018/11/12/Spark-实现kmeans算法/ Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 3.0 unless stating additionally.