Spark 实现kmeans算法

spark 实现K-means算法

package kmeans;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;


import scala.Tuple2;



/**
 * K-means clustering of 2-D points on Spark (Lloyd's algorithm, k = 4).
 *
 * <p>Initial centers are read from a local text file and the points from a Spark-readable
 * path; both contain one {@code x,y} pair per line. Each iteration assigns every point to its
 * nearest center (squared Euclidean distance), replaces each center with the mean of its
 * assigned points, and stops once the centers stop moving or after {@link #MAX_ITERATIONS}
 * iterations.
 *
 * <p>Per-iteration sums and counts are computed on the executors and returned to the driver
 * with {@code collectAsMap()}. The Spark closures never read or write the static fields
 * below (executors do not see the driver's statics in cluster mode, and concurrent
 * {@code ++} from {@code local[*]} threads would race).
 */
public class kmeans {
    /** Number of clusters. */
    static final int K = 4;
    /** Upper bound on iterations so the job always terminates. */
    static final int MAX_ITERATIONS = 100;
    /**
     * Convergence threshold on the summed squared movement of all centers. Not exactly 0
     * because shuffle order can change the floating-point summation order between iterations.
     */
    static final double EPSILON = 1e-9;

    private static final String DEFAULT_CENTERS_PATH = "/usr/local/hadoop-2.7.3/centers.txt";
    private static final String DEFAULT_DATA_PATH = "spark/input4/k-means.dat";

    static double[][] center = new double[K][2];    // current centers: 4 centers, 2 dimensions
    static int[] number = new int[K];                // points assigned to each center in the last iteration
    static double[][] new_center = new double[K][2]; // centers computed in the last iteration

    /**
     * Runs k-means and prints the centers after every iteration and at the end.
     *
     * @param args optional: {@code args[0]} = local centers file (default
     *     /usr/local/hadoop-2.7.3/centers.txt), {@code args[1]} = data path readable by
     *     {@code textFile} (default spark/input4/k-means.dat)
     * @throws IllegalStateException if the centers file cannot be read or has fewer than K centers
     */
    public static void main(String[] args) {
        String centersPath = args.length > 0 ? args[0] : DEFAULT_CENTERS_PATH;
        String dataPath = args.length > 1 ? args[1] : DEFAULT_DATA_PATH;
        readCenters(centersPath);

        SparkConf conf = new SparkConf().setAppName("kmeans").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // Parse once and cache: every iteration rescans the same points.
            JavaRDD<double[]> points = jsc.textFile(dataPath)
                    .filter(new Function<String, Boolean>() {
                        private static final long serialVersionUID = 1L;

                        @Override
                        public Boolean call(String line) {
                            return !line.trim().isEmpty();
                        }
                    })
                    .map(new Function<String, double[]>() {
                        private static final long serialVersionUID = 1L;

                        @Override
                        public double[] call(String line) {
                            return parsePoint(line);
                        }
                    })
                    .cache();

            boolean converged = false;
            int iteration = 0;
            while (!converged && iteration < MAX_ITERATIONS) {
                iteration++;

                // Snapshot of the centers; it is serialized into the closure and shipped to
                // the executors (static fields are not).
                final double[][] current = new double[K][];
                for (int i = 0; i < K; i++) {
                    current[i] = center[i].clone();
                }

                // key: index of the nearest center; value: {sumX, sumY, count}.
                Map<Integer, double[]> sums = points
                        .mapToPair(new PairFunction<double[], Integer, double[]>() {
                            private static final long serialVersionUID = 1L;

                            @Override
                            public Tuple2<Integer, double[]> call(double[] p) {
                                return new Tuple2<Integer, double[]>(
                                        nearestCenter(p[0], p[1], current),
                                        new double[] {p[0], p[1], 1.0});
                            }
                        })
                        .reduceByKey(new Function2<double[], double[], double[]>() {
                            private static final long serialVersionUID = 1L;

                            @Override
                            public double[] call(double[] a, double[] b) {
                                return new double[] {a[0] + b[0], a[1] + b[1], a[2] + b[2]};
                            }
                        })
                        .collectAsMap(); // at most K small records come back to the driver

                double movement = 0.0;
                for (int i = 0; i < K; i++) {
                    double[] s = sums.get(i);
                    if (s == null) {
                        // Empty cluster: keep its previous position instead of collapsing to (0,0).
                        number[i] = 0;
                        new_center[i][0] = center[i][0];
                        new_center[i][1] = center[i][1];
                    } else {
                        number[i] = (int) s[2];
                        new_center[i][0] = s[0] / s[2]; // floating-point mean, no truncation
                        new_center[i][1] = s[1] / s[2];
                    }
                    System.out.println("the new center: " + i + " " + new_center[i][0] + " , " + new_center[i][1]);

                    double dx = center[i][0] - new_center[i][0];
                    double dy = center[i][1] - new_center[i][1];
                    movement += dx * dx + dy * dy;
                }

                for (int i = 0; i < K; i++) {
                    center[i][0] = new_center[i][0];
                    center[i][1] = new_center[i][1];
                }
                converged = movement <= EPSILON;
            }

            if (!converged) {
                System.out.println("k-means did not converge after " + MAX_ITERATIONS + " iterations");
            }
            for (int j = 0; j < K; j++) {
                System.out.println("the final center: " + " " + center[j][0] + " , " + center[j][1]);
            }
        } finally {
            jsc.stop();
        }
    }

    /**
     * Loads the first K non-blank {@code x,y} lines of a local file into {@link #center}.
     *
     * @param path local file path
     * @throws IllegalStateException if the file cannot be read or holds fewer than K centers
     */
    private static void readCenters(String path) {
        List<String> lines = new ArrayList<String>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    lines.add(line);
                }
            }
        } catch (IOException e) {
            throw new IllegalStateException("cannot read centers from " + path, e);
        }
        if (lines.size() < K) {
            throw new IllegalStateException(
                    "expected " + K + " centers in " + path + " but found " + lines.size());
        }
        for (int i = 0; i < K; i++) {
            center[i] = parsePoint(lines.get(i));
        }
    }

    /**
     * Parses one {@code x,y} line into a two-element array; surrounding whitespace is ignored.
     *
     * @throws IllegalArgumentException if the line has fewer than two comma-separated fields
     * @throws NumberFormatException if a field is not a number
     */
    private static double[] parsePoint(String line) {
        String[] parts = line.split(",");
        if (parts.length < 2) {
            throw new IllegalArgumentException("expected \"x,y\" but got: " + line);
        }
        return new double[] {Double.parseDouble(parts[0].trim()), Double.parseDouble(parts[1].trim())};
    }

    /**
     * Returns the index of the center closest to (x, y) by squared Euclidean distance;
     * ties go to the lowest index.
     */
    private static int nearestCenter(double x, double y, double[][] centers) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.length; i++) {
            double dx = x - centers[i][0];
            double dy = y - centers[i][1];
            double d = dx * dx + dy * dy;
            if (d < bestDistance) {
                bestDistance = d;
                best = i;
            }
        }
        return best;
    }
}

输入:

1. centers.txt（初始中心点，每行一个 "x,y" 坐标，共 4 行）:
    96,826
    606,776
    474,866
    400,768
2. data.dat（数据文件，代码中读取的路径为 spark/input4/k-means.dat）:
    存放所有点的坐标，每行一个 "x,y"。