티스토리 뷰

(2013년에 작성했던 코드를 백업하기 위한 글입니다. 코드는 추후 수정할 예정입니다.)

 

import java.util.ArrayList;
 
 
/**
 * A minimal back-propagation neural network (BPNN) with one hidden layer, tanh activations
 * and a momentum term, kept entirely in static state.
 *
 * <p>{@link #_init_(int, int, int)} must be called before {@link #update(double[])},
 * {@link #backPropagate(String, int, double, double)}, {@link #train} or {@link #test}.
 * Target vectors are passed as strings whose values are separated by {@code '^'},
 * e.g. {@code "0^0^0^1"}.
 */
public class _5_BPNN_BK_First {

    /** Default error threshold below which {@link #train} stops before using all iterations. */
    private static final double DEFAULT_TARGET_ERROR = 0.002;

    /** Number of input nodes, including the trailing bias node. */
    private static int ni;
    /** Number of hidden nodes. */
    private static int nh;
    /** Number of output nodes. */
    private static int no;
    /** Input activations; the last element is the bias node and stays 1.0. */
    private static double[] ai;
    /** Hidden activations produced by the most recent {@link #update} call. */
    private static double[] ah;
    /** Output activations produced by the most recent {@link #update} call. */
    private static double[] ao;
    /** Input-to-hidden weights, indexed [input][hidden]. */
    private static double[][] wi;
    /** Hidden-to-output weights, indexed [hidden][output]. */
    private static double[][] wo;
    /** Previous unscaled input-to-hidden weight change, reused as the momentum term. */
    private static double[][] ci;
    /** Previous unscaled hidden-to-output weight change, reused as the momentum term. */
    private static double[][] co;

    /**
     * Returns a uniformly distributed pseudo-random number in {@code [a, b)} (for {@code a < b}).
     */
    public static double rand(double a, double b) {
        return (b - a) * Math.random() + a;
    }

    /**
     * Allocates the network and initializes its weights.
     *
     * @param init_ni number of input values per pattern; one bias node is added internally
     * @param init_nh number of hidden nodes
     * @param init_no number of output nodes
     */
    public static void _init_(int init_ni, int init_nh, int init_no) {
        ni = init_ni + 1; // +1 for the bias node
        nh = init_nh;
        no = init_no;

        ai = filled(ni, 1.0);
        ah = filled(nh, 1.0);
        ao = filled(no, 1.0);

        // Small random weights break the symmetry between hidden nodes.
        wi = randomMatrix(ni, nh, -0.2, 0.2);
        wo = randomMatrix(nh, no, -0.2, 0.2);

        // No previous change exists yet, so the momentum terms start at zero.
        ci = new double[ni][nh];
        co = new double[nh][no];
    }

    /** Returns a new array of the given length with every element set to {@code value}. */
    private static double[] filled(int length, double value) {
        double[] result = new double[length];
        for (int i = 0; i < length; i++) {
            result[i] = value;
        }
        return result;
    }

    /** Returns a rows x cols matrix filled row by row with {@link #rand(double, double)}. */
    private static double[][] randomMatrix(int rows, int cols, double low, double high) {
        double[][] m = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                m[i][j] = rand(low, high);
            }
        }
        return m;
    }

    /**
     * Activation function. Despite its name this is the hyperbolic tangent, so the output
     * lies in (-1, 1) rather than (0, 1).
     */
    public static double sigmoid(double x) {
        return Math.tanh(x);
    }

    /**
     * Derivative of {@link #sigmoid} written in terms of its output:
     * if {@code y = tanh(x)} then {@code dy/dx = 1 - y^2}.
     *
     * @param y an activation value previously returned by {@link #sigmoid}
     */
    public static double dsigmoid(double y) {
        return 1.0 - y * y;
    }

    /**
     * Runs a forward pass and stores the activations of every layer.
     *
     * <p>If {@code input} does not have exactly one value per (non-bias) input node, prints
     * {@code "wrong number of inputs"} and leaves all activations unchanged.
     *
     * @param input one input pattern; it is only read, never modified
     */
    public static void update(double[] input) {
        if (input.length != ni - 1) {
            System.out.println("wrong number of inputs");
            return;
        }

        // Copy the pattern into the input layer; ai[ni - 1] is the bias and keeps its 1.0.
        for (int i = 0; i < ni - 1; i++) {
            ai[i] = input[i];
        }

        for (int j = 0; j < nh; j++) {
            double sum = 0.0;
            for (int i = 0; i < ni; i++) {
                sum = sum + ai[i] * wi[i][j];
            }
            ah[j] = sigmoid(sum);
        }

        for (int k = 0; k < no; k++) {
            double sum = 0.0;
            for (int j = 0; j < nh; j++) {
                sum = sum + ah[j] * wo[j][k];
            }
            ao[k] = sigmoid(sum);
        }
    }

    /**
     * Performs one gradient-descent-with-momentum step towards {@code str_target}, using the
     * activations left by the most recent {@link #update} call.
     *
     * @param str_target target values separated by {@code '^'}, one per output node
     * @param p          index of the pattern being trained; currently unused
     * @param N          learning rate
     * @param M          momentum factor applied to the previous weight change
     * @return {@code 0.5 * sum((target - output)^2)} over the outputs computed before this
     *         step, or 0.0 (after printing a message) if the number of targets does not
     *         match the number of output nodes
     */
    public static double backPropagate(String str_target, int p, double N, double M) {
        String[] target = str_target.split("\\^"); // '^' is a regex metacharacter
        if (target.length != no) {
            System.out.println("wrong number of target values" + target.length + ":" + no);
            return 0.0;
        }
        double[] t = new double[no];
        for (int k = 0; k < no; k++) {
            t[k] = Double.parseDouble(target[k]);
        }

        // Output deltas: error scaled by the activation slope at the current output.
        double[] outputDeltas = new double[no];
        for (int k = 0; k < no; k++) {
            outputDeltas[k] = dsigmoid(ao[k]) * (t[k] - ao[k]);
        }

        // Hidden deltas are computed before wo is updated, so they use the old weights.
        double[] hiddenDeltas = new double[nh];
        for (int j = 0; j < nh; j++) {
            double err = 0.0;
            for (int k = 0; k < no; k++) {
                err = err + outputDeltas[k] * wo[j][k];
            }
            hiddenDeltas[j] = dsigmoid(ah[j]) * err;
        }

        // Hidden->output weights: gradient step plus momentum from the previous change.
        for (int j = 0; j < nh; j++) {
            for (int k = 0; k < no; k++) {
                double change = outputDeltas[k] * ah[j];
                wo[j][k] = wo[j][k] + N * change + M * co[j][k];
                co[j][k] = change;
            }
        }

        // Input->hidden weights: same update rule.
        for (int i = 0; i < ni; i++) {
            for (int j = 0; j < nh; j++) {
                double change = hiddenDeltas[j] * ai[i];
                wi[i][j] = wi[i][j] + N * change + M * ci[i][j];
                ci[i][j] = change;
            }
        }

        double error = 0.0;
        for (int k = 0; k < no; k++) {
            double diff = t[k] - ao[k];
            error = error + 0.5 * diff * diff;
        }
        return error;
    }

    /**
     * Trains on all patterns for up to {@code iterations} epochs. Training stops early once an
     * epoch's summed error falls below 0.002.
     *
     * @see #train(double[][], String[], int, double, double, double)
     */
    public static void train(double patterns[][], String targets[], int iterations, double N, double M) {
        train(patterns, targets, iterations, N, M, DEFAULT_TARGET_ERROR);
    }

    /**
     * Trains on all patterns for up to {@code iterations} epochs. Training stops early as soon
     * as an epoch's summed error is below {@code targetError}. The epoch error is printed
     * every 100 epochs.
     *
     * @param patterns    input patterns, one row per sample
     * @param targets     '^'-separated target vectors, parallel to {@code patterns}
     * @param iterations  maximum number of epochs
     * @param N           learning rate
     * @param M           momentum factor
     * @param targetError early-stopping threshold on the summed epoch error
     */
    public static void train(double[][] patterns, String[] targets, int iterations,
                             double N, double M, double targetError) {
        for (int i = 0; i < iterations; i++) {
            double error = 0.0;
            for (int p = 0; p < patterns.length; p++) {
                update(patterns[p]); // update only reads its argument, so no copy is needed
                error = error + backPropagate(targets[p], p, N, M);
            }

            if (i % 100 == 0) {
                System.out.println(error);
            }
            // Converged: stop instead of running the remaining epochs.
            if (error < targetError) {
                break;
            }
        }
    }

    /**
     * Runs a forward pass for each pattern and prints the output activations on one line,
     * each value followed by a space.
     */
    public static void test(double[][] Ts_patterns) {
        for (double[] pattern : Ts_patterns) {
            update(pattern); // each row is used with its own length
            StringBuilder line = new StringBuilder();
            for (double out : ao) {
                line.append(out).append(' ');
            }
            System.out.println(line + " ");
        }
    }

    /** Prints every input-to-hidden weight, then every hidden-to-output weight, one per line. */
    public static void weights() {
        System.out.println("input weights: ");
        for (int i = 0; i < wi.length; i++) {
            for (int j = 0; j < wi[0].length; j++) {
                System.out.println(wi[i][j]);
            }
        }

        System.out.println("Output weights: ");
        for (int i = 0; i < wo.length; i++) {
            for (int j = 0; j < wo[0].length; j++) {
                System.out.println(wo[i][j]);
            }
        }
    }

    /**
     * Returns the numeric input fields of each record: every comma-separated field except the
     * last one.
     */
    private static double[][] parseInputs(String[] records) {
        double[][] inputs = new double[records.length][];
        for (int i = 0; i < records.length; i++) {
            String[] fields = records[i].split(",");
            inputs[i] = new double[fields.length - 1];
            for (int j = 0; j < fields.length - 1; j++) {
                inputs[i][j] = Double.parseDouble(fields[j]);
            }
        }
        return inputs;
    }

    /** Returns the last comma-separated field (the '^'-separated target vector) of each record. */
    private static String[] parseTargets(String[] records) {
        String[] targets = new String[records.length];
        for (int i = 0; i < records.length; i++) {
            String[] fields = records[i].split(",");
            targets[i] = fields[fields.length - 1];
        }
        return targets;
    }

    /**
     * Demo: parses a small training set, trains the network and prints its outputs on a test set.
     *
     * <p>Each record holds comma-separated inputs followed by a '^'-separated target vector as
     * its last field. Training records are separated by '#' and test records by '|'.
     */
    public static void main(String[] args) throws Exception {
        String Sample = "0,0,0,0,0^0^0^1#0,0,0,1,0^0^1^0#0,0,1,0,0^1^0^0#0,0,1,1,1^0^0^0";
        int hidden_layer = 5;

        String[] extraction = Sample.split("#"); // '#' is not a regex metacharacter
        double[][] patterns = parseInputs(extraction);
        String[] targets = parseTargets(extraction);

        // Layer sizes are taken from the first record.
        int input_layer = patterns[0].length;
        int output_layer = targets[0].split("\\^").length;

        // Echo the parsed data.
        for (int i = 0; i < patterns.length; i++) {
            for (int j = 0; j < patterns[i].length; j++) {
                System.out.print(j + ":" + patterns[i][j] + " ");
            }
            System.out.println(" " + targets[i]);
            System.out.println(" ");
        }

        System.out.println("Initalizing");
        _init_(input_layer, hidden_layer, output_layer);

        int iterations = 1000;
        double N = 0.5; // learning rate
        double M = 0.1; // momentum factor
        train(patterns, targets, iterations, N, M);

        String Ts_Sample = "0,0,0,0,0^0^0^1|0,0,0,1,0^0^1^0|0,0,1,0,0^1^0^0|0,0,1,1,1^0^0^0";
        String[] Ts_extraction = Ts_Sample.split("\\|"); // '|' is a regex metacharacter
        test(parseInputs(Ts_extraction));
    }

}

 

 

https://blog.naver.com/johoonx2/110167782018

 

Java 기반의 BPNN(인공신경망) by 조훈

입력값/출력값 Sample = "0,0,0,0,0^0^0^1#0,0,0,1,0^0^1^0#0,0,1,0,0^1^0^0#0,0,1,1,1^0^0^0"; 0,0,0,0,0...

blog.naver.com

 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함