티스토리 뷰
(2013년에 작성했던 코드 백업을 위함 - 추후 코드 수정 예정)
import java.util.ArrayList;
/**
 * A small three-layer (input / hidden / output) back-propagation neural network
 * trained with momentum. All network state is held in static fields, so only one
 * network exists at a time; call {@link #_init_(int, int, int)} before use.
 *
 * <p>Hidden and output units use {@code tanh} as their activation function (the
 * method is historically named {@code sigmoid}). Training targets are strings of
 * output values separated by {@code '^'}, e.g. {@code "0^0^1"}.
 */
public class _5_BPNN_BK_First {
    /** Training stops early once the summed error of an epoch drops below this value. */
    private static final double TARGET_ERROR = 0.002;

    private static int ni; // number of input nodes, including the extra bias node
    private static int nh; // number of hidden nodes
    private static int no; // number of output nodes
    private static double[] ai; // input activations; the last entry stays 1.0 (bias)
    private static double[] ah; // hidden activations
    private static double[] ao; // output activations
    private static double[][] wi; // input -> hidden weights, shape [ni][nh]
    private static double[][] wo; // hidden -> output weights, shape [nh][no]
    private static double[][] ci; // previous input-weight changes, used for momentum
    private static double[][] co; // previous output-weight changes, used for momentum

    /**
     * Returns a random number drawn uniformly from {@code [a, b)}.
     *
     * @param a lower bound (inclusive)
     * @param b upper bound (exclusive)
     * @return a pseudo-random value between {@code a} and {@code b}
     */
    public static double rand(double a, double b) {
        return (b - a) * Math.random() + a;
    }

    /**
     * (Re)initializes the network: all activations are set to 1.0, weights are drawn
     * from {@code [-0.2, 0.2)} and the momentum buffers are cleared.
     *
     * @param init_ni number of input values per pattern (a bias node is added)
     * @param init_nh number of hidden nodes
     * @param init_no number of output nodes
     */
    public static void _init_(int init_ni, int init_nh, int init_no) {
        ni = init_ni + 1; // +1 for the bias node, whose activation is never overwritten
        nh = init_nh;
        no = init_no;
        ai = filled(ni, 1.0);
        ah = filled(nh, 1.0);
        ao = filled(no, 1.0);
        wi = randomMatrix(ni, nh);
        wo = randomMatrix(nh, no);
        ci = new double[ni][nh];
        co = new double[nh][no];
    }

    /** Returns a new array of the given length with every element set to {@code value}. */
    private static double[] filled(int length, double value) {
        double[] array = new double[length];
        for (int i = 0; i < length; i++) {
            array[i] = value;
        }
        return array;
    }

    /** Returns a rows x cols matrix filled row by row with values from {@code rand(-0.2, 0.2)}. */
    private static double[][] randomMatrix(int rows, int cols) {
        double[][] matrix = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                matrix[i][j] = rand(-0.2, 0.2);
            }
        }
        return matrix;
    }

    /**
     * Activation function of the network. Despite the name, this is the hyperbolic
     * tangent, with range (-1, 1).
     */
    public static double sigmoid(double x) {
        return Math.tanh(x);
    }

    /**
     * Derivative of {@code tanh}, expressed in terms of its output: {@code 1 - y^2}.
     *
     * @param y an activation previously produced by {@link #sigmoid(double)}
     */
    public static double dsigmoid(double y) {
        return 1.0 - y * y;
    }

    /**
     * Runs a forward pass, storing the resulting activations in {@code ai}, {@code ah}
     * and {@code ao}. If {@code input} has the wrong length, a message is printed and
     * the network state is left unchanged.
     *
     * @param input one input pattern; must contain exactly {@code init_ni} values
     */
    public static void update(double[] input) {
        if (input.length != ni - 1) {
            System.out.println("wrong number of inputs");
            return;
        }
        // Copy the inputs; ai[ni - 1] is the bias node and keeps its value of 1.0.
        System.arraycopy(input, 0, ai, 0, ni - 1);

        for (int j = 0; j < nh; j++) {
            double sum = 0.0;
            for (int i = 0; i < ni; i++) {
                sum += ai[i] * wi[i][j];
            }
            ah[j] = sigmoid(sum);
        }

        for (int k = 0; k < no; k++) {
            double sum = 0.0;
            for (int j = 0; j < nh; j++) {
                sum += ah[j] * wo[j][k];
            }
            ao[k] = sigmoid(sum);
        }
    }

    /**
     * Performs one back-propagation step for the pattern most recently passed to
     * {@link #update(double[])} and adjusts the weights.
     *
     * @param str_target expected outputs separated by {@code '^'}, e.g. {@code "0^1"}
     * @param p index of the pattern (unused; kept for API compatibility)
     * @param N learning rate
     * @param M momentum factor applied to the previous weight change
     * @return half the sum of squared output errors, computed from the activations of
     *     the preceding forward pass; 0.0 if the number of targets is wrong
     */
    public static double backPropagate(String str_target, int p, double N, double M) {
        // split() takes a regex, so the '^' anchor must be escaped to match literally.
        String[] target = str_target.split("\\^");
        if (target.length != no) {
            System.out.println("wrong number of target values" + target.length + ":" + no);
            return 0.0;
        }
        double[] expected = new double[no];
        for (int k = 0; k < no; k++) {
            expected[k] = Double.parseDouble(target[k]);
        }

        // Error terms for the output layer.
        double[] outputDeltas = new double[no];
        for (int k = 0; k < no; k++) {
            outputDeltas[k] = dsigmoid(ao[k]) * (expected[k] - ao[k]);
        }

        // Error terms for the hidden layer (must use wo before it is updated below).
        double[] hiddenDeltas = new double[nh];
        for (int j = 0; j < nh; j++) {
            double error = 0.0;
            for (int k = 0; k < no; k++) {
                error += outputDeltas[k] * wo[j][k];
            }
            hiddenDeltas[j] = dsigmoid(ah[j]) * error;
        }

        // Update hidden -> output weights (gradient step plus momentum).
        for (int j = 0; j < nh; j++) {
            for (int k = 0; k < no; k++) {
                double change = outputDeltas[k] * ah[j];
                wo[j][k] += N * change + M * co[j][k];
                co[j][k] = change;
            }
        }

        // Update input -> hidden weights (gradient step plus momentum).
        for (int i = 0; i < ni; i++) {
            for (int j = 0; j < nh; j++) {
                double change = hiddenDeltas[j] * ai[i];
                wi[i][j] += N * change + M * ci[i][j];
                ci[i][j] = change;
            }
        }

        double error = 0.0;
        for (int k = 0; k < no; k++) {
            double diff = expected[k] - ao[k];
            error += 0.5 * diff * diff;
        }
        return error;
    }

    /**
     * Trains the network for at most {@code iterations} epochs. The run stops early
     * once an epoch's summed error falls below {@value #TARGET_ERROR}. The epoch error
     * is printed every 100 epochs.
     *
     * @param patterns input patterns, one row per sample
     * @param targets {@code '^'}-separated expected outputs, aligned with {@code patterns}
     * @param iterations maximum number of epochs
     * @param N learning rate
     * @param M momentum factor
     */
    public static void train(double patterns[][], String targets[], int iterations, double N, double M) {
        for (int i = 0; i < iterations; i++) {
            double error = 0.0;
            for (int p = 0; p < patterns.length; p++) {
                update(patterns[p]); // update() only reads the array, so no copy is needed
                error += backPropagate(targets[p], p, N, M);
            }
            if (i % 100 == 0) {
                System.out.println(error);
            }
            // Early stop. The original loop condition ("|| error < 0.002") never ended
            // once the network had converged.
            if (error < TARGET_ERROR) {
                break;
            }
        }
    }

    /**
     * Runs each pattern through the network and prints the output activations, one
     * line per pattern.
     *
     * @param Ts_patterns input patterns to evaluate
     */
    public static void test(double[][] Ts_patterns) {
        for (double[] pattern : Ts_patterns) {
            update(pattern);
            for (double output : ao) {
                System.out.print(output + " ");
            }
            System.out.println(" ");
        }
    }

    /** Prints every input weight, then every output weight, one value per line. */
    public static void weights() {
        System.out.println("input weights: ");
        for (double[] row : wi) {
            for (double w : row) {
                System.out.println(w);
            }
        }
        System.out.println("Output weights: ");
        for (double[] row : wo) {
            for (double w : row) {
                System.out.println(w);
            }
        }
    }

    /**
     * Parses rows of the form {@code "x1,...,xk,t1^...^tm"} into input vectors. The
     * width k is taken from the first row.
     */
    private static double[][] parseInputs(String[] rows) {
        int width = rows[0].split(",").length - 1;
        double[][] inputs = new double[rows.length][width];
        for (int i = 0; i < rows.length; i++) {
            String[] fields = rows[i].split(",");
            for (int j = 0; j < fields.length - 1; j++) {
                inputs[i][j] = Double.parseDouble(fields[j]);
            }
        }
        return inputs;
    }

    /** Extracts the trailing {@code '^'}-separated target field of each comma-separated row. */
    private static String[] parseTargets(String[] rows) {
        String[] targets = new String[rows.length];
        for (int i = 0; i < rows.length; i++) {
            String[] fields = rows[i].split(",");
            targets[i] = fields[fields.length - 1];
        }
        return targets;
    }

    /** Demo: trains the network on a tiny data set, then evaluates it. */
    public static void main(String[] args) throws Exception {
        // Training data: rows separated by '#', each row is "x1,...,xk,t1^...^tm".
        String Sample = "0,0,0,0,0^0^0^1#0,0,0,1,0^0^1^0#0,0,1,0,0^1^0^0#0,0,1,1,1^0^0^0";
        int hidden_layer = 5;
        int output_layer = 4;

        String[] rows = Sample.split("#");
        double[][] patterns = parseInputs(rows);
        String[] targets = parseTargets(rows);
        int input_layer = patterns[0].length;

        // Echo the parsed training data.
        for (int i = 0; i < patterns.length; i++) {
            for (int j = 0; j < patterns[0].length; j++) {
                System.out.print(j + ":" + patterns[i][j] + " ");
            }
            System.out.println(" " + targets[i]);
            System.out.println(" ");
        }

        System.out.println("Initalizing");
        _init_(input_layer, hidden_layer, output_layer);

        int iterations = 1000;
        double N = 0.5;
        double M = 0.1;
        train(patterns, targets, iterations, N, M);

        // Test data uses '|' as the row separator; '|' must be escaped in the regex.
        String Ts_Sample = "0,0,0,0,0^0^0^1|0,0,0,1,0^0^1^0|0,0,1,0,0^1^0^0|0,0,1,1,1^0^0^0";
        test(parseInputs(Ts_Sample.split("\\|")));
    }
}
https://blog.naver.com/johoonx2/110167782018
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- Linux
- pharser3
- 민팅
- 채굴
- nodejs
- node
- minting
- 뱀파이어 서바이벌
- P3X Redis UI
- 이더리움
- 뱀파이어 사바이벌
- 몽고db
- 비트코인
- 이더리움 채굴기
- Vampire Survivor
- OpenSea
- phaser3
- phaser
- remote-ftp
- node.js
- go lang
- 모니터 설정
- 회원 탈퇴
- 지갑 생성
- mysql
- 네이버 클라우드 플랫폼
- GO
- mongodb
- pharser
- krafterspace
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
글 보관함