jimin 发表于 2007-10-13 15:55

隐节点合成算法matlab程序

该程序来源于《神经网络结构设计的理论与方法》一书，
我将它录入并调试通过。
相关的理论部分可以参考《神经网络结构设计的理论与方法》一书。
程序较多，输入和调试花了我几天时间。
一共有一个主函数和7个子函数，
我就不在这里重复贴出了，请见：
http://www.2nsoft.cn/bbs/read.php?tid=7675
有兴趣的可以学习一下。
为方便起见，也可以把程序转帖过来。

jimin 发表于 2007-10-13 22:51

function main()
% MAIN  Hidden-unit combination (pruning) demo for a two-hidden-layer
% sigmoid network trained by back-propagation with a momentum term.
%
% Task: 5-bit binary inputs, target = (x1 | x2) & (x3 | x4 | x5).
% 24 of the 32 patterns (random split) are used for training and the other
% 8 for testing. Whenever the training SSE is below ERRCOMBINE, one pair of
% strongly correlated hidden units is merged (or one near-constant unit is
% folded into the next layer's bias) and training continues with the
% smaller network. Training stops at ERRGOAL or after MAXEPOCH epochs.
%
% Uses LOGSIG, SUMSQR and RANDS (Neural Network Toolbox).

% ---- problem and network sizes ----
indim=5;
outdim=1;
hidden1unitnum=5;
hidden2unitnum=5;
allsamnum=32;
traindatanum=24;
testdatanum=8;

% ---- data: every indim-bit pattern as one column, MSB in row 1 ----
% dec2bin(...,indim) left-pads with zeros so each pattern has indim bits.
allsamin=dec2bin(0:allsamnum-1,indim)'-'0';
allsamout=(allsamin(1,:)|allsamin(2,:))&...
    (allsamin(3,:)|allsamin(4,:)|allsamin(5,:));

% ---- random train/test split ----
permpos=randperm(allsamnum);
trainidx=permpos(1:traindatanum);
testidx=permpos(traindatanum+1:traindatanum+testdatanum);
traindatain=allsamin(:,trainidx);
traindataout=allsamout(:,trainidx);
testdatain=allsamin(:,testidx);
testdataout=allsamout(:,testidx);

% ---- initial weights and biases ----
w1=0.5*rands(hidden1unitnum,indim);
b1=0.5*rands(hidden1unitnum,1);
w2=0.5*rands(hidden2unitnum,hidden1unitnum);
b2=0.5*rands(hidden2unitnum,1);
w3=0.5*rands(outdim,hidden2unitnum);
b3=0.5*rands(outdim,1);

% ---- training parameters ----
lr=0.9;
alpha=0.9;
maxepoch=2000;
errcombine=0.001;
errgoal=0.00005;
unitscombinethreshold=0.8;
biascombinethreshold=0.01;

% Bias handling: each extended weight matrix carries its bias in the last
% column, and each layer input gets a matching extra row of ones.
w1ex=[w1,b1];
w2ex=[w2,b2];
w3ex=[w3,b3];
biasrow=ones(1,traindatanum);
traindatainex=[traindatain;biasrow];
errhistory=[];
resizeflag=1;
for epoch=1:maxepoch
    if resizeflag==1
        % First epoch, or a unit was just removed: re-read the layer sizes
        % (w2ex is hidden2 x (hidden1+1)) and reset the momentum terms,
        % whose shapes no longer match the weights.
        [hidden2unitnum,hidden1unitnum]=size(w2ex);
        hidden1unitnum=hidden1unitnum-1;
        w2=w2ex(:,1:hidden1unitnum);
        w3=w3ex(:,1:hidden2unitnum);
        dw1ex=zeros(size(w1ex));
        dw2ex=zeros(size(w2ex));
        dw3ex=zeros(size(w3ex));
        resizeflag=0;
    end

    % ---- forward pass ----
    hidden1out=logsig(w1ex*traindatainex);
    hidden1outex=[hidden1out;biasrow];
    hidden2out=logsig(w2ex*hidden1outex);
    hidden2outex=[hidden2out;biasrow];
    networkout=logsig(w3ex*hidden2outex);
    err=traindataout-networkout;   % not "error": that shadows the builtin
    sse=sumsqr(err);
    errhistory=[errhistory,sse]; %#ok<AGROW>

    % ---- structure adjustment: at most one merge per epoch ----
    % Hidden layer 1 is tried first; layer 2 only if layer 1 had no merge.
    if sse<errcombine
        [w1ex,w2ex,combined]=combinelayerunits(hidden1out,w1ex,w2ex,...
            unitscombinethreshold,biascombinethreshold,epoch,1);
        if ~combined
            [w2ex,w3ex,combined]=combinelayerunits(hidden2out,w2ex,w3ex,...
                unitscombinethreshold,biascombinethreshold,epoch,2);
        end
        if combined
            resizeflag=1;
            continue;
        end
    end

    if sse<errgoal, break, end

    % ---- back-propagation (logsig derivative is y.*(1-y)) ----
    delta3=err.*networkout.*(1-networkout);
    delta2=(w3'*delta3).*hidden2out.*(1-hidden2out);
    delta1=(w2'*delta2).*hidden1out.*(1-hidden1out);
    % Momentum term = lr times the previous epoch's raw gradient step; the
    % previous momentum contribution itself is not carried forward.
    dw1ex0=lr*dw1ex;
    dw2ex0=lr*dw2ex;
    dw3ex0=lr*dw3ex;
    dw3ex=delta3*hidden2outex';
    dw2ex=delta2*hidden1outex';
    dw1ex=delta1*traindatainex';
    w1ex=w1ex+lr*dw1ex+alpha*dw1ex0;
    w2ex=w2ex+lr*dw2ex+alpha*dw2ex0;
    w3ex=w3ex+lr*dw3ex+alpha*dw3ex0;
    w2=w2ex(:,1:hidden1unitnum);
    w3=w3ex(:,1:hidden2unitnum);
end

% The loop can end right after a merge ("continue" on the last epoch),
% before the sizes are re-read, so recompute them from the final weights.
[hidden2unitnum,hidden1unitnum]=size(w2ex);
hidden1unitnum=hidden1unitnum-1;
fprintf('hidden layer sizes after combining: %d, %d\n',...
    hidden1unitnum,hidden2unitnum);

% ---- split off the biases and evaluate on the held-out patterns ----
w1=w1ex(:,1:indim);
b1=w1ex(:,indim+1);
w2=w2ex(:,1:hidden1unitnum);
b2=w2ex(:,hidden1unitnum+1);
w3=w3ex(:,1:hidden2unitnum);
b3=w3ex(:,hidden2unitnum+1);
testnnout=bpnet(testdatain,w1,b1,w2,b2,w3,b3);
binout=testnnout>0.5;
% Number of misclassified test patterns (thresholded output vs. target).
errnum=sum(binout~=testdataout)

% ---- training-error curve (SSE spans several decades: log scale) ----
figure
semilogy(1:numel(errhistory),errhistory,'r-');
grid on
xlabel('epoch');
ylabel('SSE');

function [winex,woutex,combined]=combinelayerunits(layerout,winex,woutex,...
    unitscombinethreshold,biascombinethreshold,epoch,layer)
% COMBINELAYERUNITS  Attempt one unit merge inside a single hidden layer.
%   LAYEROUT is the layer output (units x training samples). WINEX and
%   WOUTEX are the extended weight matrices into and out of the layer
%   (bias in the last column). If a merge is made, the updated matrices are
%   returned and COMBINED is true. The printed combinetype is 10*LAYER+1
%   for a correlated pair and 10*LAYER+2 for a unit folded into the bias.
combined=false;
if size(layerout,1)<2
    return   % keep at least one unit so the layer is never emptied
end
[unit1,unit2]=findunittocombine(corrcoef(layerout'),var(layerout,0,2),...
    unitscombinethreshold,biascombinethreshold);
if unit1==0
    return
end
if unit2>0
    % unit2 output ~ a*unit1 output + b: fold unit2 into unit1 and the bias.
    fprintf('epoch %d: combinetype=%d\n',epoch,10*layer+1);
    [a,b]=linearreg(layerout(unit1,:),layerout(unit2,:));
    drawcorrelatedunitsout(layerout(unit1,:),layerout(unit2,:));
    [winex,woutex]=combinetwounits(unit1,unit2,a,b,winex,woutex);
else
    % unit1 is nearly constant: replace it by its mean inside the bias.
    fprintf('epoch %d: combinetype=%d\n',epoch,10*layer+2);
    drawbiasedunitout(layerout(unit1,:));
    [winex,woutex]=combineunittobias(unit1,mean(layerout(unit1,:)),...
        winex,woutex);
end
combined=true;
function [unit1,unit2]=findunittocombine(hiddencorr,hiddenvar,...
    unitscombinethreshold,biascombinethreshold)
% FINDUNITTOCOMBINE  Choose hidden units of one layer that can be merged.
%   [UNIT1,UNIT2] = FINDUNITTOCOMBINE(HIDDENCORR,HIDDENVAR,UNITSTHR,BIASTHR)
%   HIDDENCORR is the unit-by-unit correlation matrix of the layer outputs
%   and HIDDENVAR the per-unit output variances.
%   - If some pair of units has |correlation| >= UNITSTHR and both
%     variances are > BIASTHR, return that pair (UNIT1 < UNIT2). The most
%     strongly correlated qualifying pair wins; ties go to the lowest
%     column, then the lowest row.
%   - Otherwise, if the smallest variance is < BIASTHR, return UNIT1 = that
%     unit and UNIT2 = 0 (fold the unit into the bias).
%   - Otherwise return UNIT1 = UNIT2 = 0 (nothing to merge).
unit1=0;
unit2=0;
% Strict upper triangle: each unordered pair appears exactly once. NaN
% correlations (from constant units) are treated as "not correlated".
corrtri=abs(triu(hiddencorr,1));
corrtri(isnan(corrtri))=0;
while true
    [maxcorr,idx]=max(corrtri(:));
    if isempty(maxcorr) || maxcorr<unitscombinethreshold
        break
    end
    [row,col]=ind2sub(size(corrtri),idx);
    if hiddenvar(row)>biascombinethreshold && hiddenvar(col)>biascombinethreshold
        unit1=row;
        unit2=col;
        return
    end
    % A near-constant unit is handled by the bias rule below; skip the pair.
    corrtri(row,col)=0;
end
if isempty(hiddenvar)
    return
end
[minvar,unit]=min(hiddenvar);
if minvar<biascombinethreshold
    unit1=unit;
    unit2=0;
end

function [a,b]=linearreg(vect1,vect2)
% LINEARREG  Least-squares line fit  vect2 ~ a*vect1 + b.
%   VECT1 and VECT2 hold paired samples (row or column vectors of equal
%   length). A is cov(vect1,vect2)/var(vect1) using population moments,
%   and B is the intercept. VECT1 must not be constant (zero variance).
v1=vect1(:);
v2=vect2(:);
n=numel(v1);
meanv1=mean(v1);
meanv2=mean(v2);
% Denominator is var(v1) = E[v1^2] - E[v1]^2 (previously used v1'*v2 by mistake).
a=(v1'*v2/n-meanv1*meanv2)/(v1'*v1/n-meanv1^2);
b=meanv2-a*meanv1;
function out=bpnet(in,w1,b1,w2,b2,w3,b3)
% BPNET  Forward pass of a two-hidden-layer logsig network.
%   IN has one input pattern per column. W1/B1, W2/B2, W3/B3 are the weight
%   matrix and bias column of each layer. OUT has one column per pattern.
innum=size(in,2);   % number of patterns
hidden1out=logsig(w1*in+repmat(b1,1,innum));
hidden2out=logsig(w2*hidden1out+repmat(b2,1,innum));
out=logsig(w3*hidden2out+repmat(b3,1,innum));

function [w1ex,w2ex]=combinetwounits(unit1,unit2,a,b,w1ex,w2ex)
% COMBINETWOUNITS  Merge hidden unit UNIT2 into UNIT1, given that
%   output(UNIT2) ~ A*output(UNIT1) + B.
%   W1EX: extended weights into the layer (one row per unit).
%   W2EX: extended weights out of the layer (one column per unit, bias in
%         the last column).
%   UNIT2's outgoing weights are redistributed: A times them onto UNIT1's
%   column and B times them onto the bias column. Then UNIT2's row of W1EX
%   and column of W2EX are removed.
biascol=size(w2ex,2);
w2ex(:,unit1)=w2ex(:,unit1)+a*w2ex(:,unit2);
w2ex(:,biascol)=w2ex(:,biascol)+b*w2ex(:,unit2);
w1ex(unit2,:)=[];
w2ex(:,unit2)=[];

function [w1ex,w2ex]=combineunittobias(unit,unitmean,w1ex,w2ex)
% COMBINEUNITTOBIAS  Replace a near-constant hidden unit by its mean value.
%   The unit's outgoing weights times UNITMEAN are added to the bias column
%   (last column of W2EX). The unit's row of W1EX (incoming weights) and
%   column of W2EX are then removed.
biascol=size(w2ex,2);
w2ex(:,biascol)=w2ex(:,biascol)+unitmean*w2ex(:,unit);
w1ex(unit,:)=[];
w2ex(:,unit)=[];

function drawcorrelatedunitsout(unitout1,unitout2)
% DRAWCORRELATEDUNITSOUT  Plot the outputs of two correlated hidden units.
%   UNITOUT1 (blue) and UNITOUT2 (black) are vectors of unit outputs, one
%   value per training sample, drawn in a new figure.
ptnum=numel(unitout1);
figure
hold on
grid on
axis([0 ptnum 0 1])   % logsig outputs lie in (0,1)
plot(1:ptnum,unitout1,'b-')
plot(1:ptnum,unitout2,'k-')

function drawbiasedunitout(unitout)
% DRAWBIASEDUNITOUT  Plot the output of a near-constant hidden unit.
%   UNITOUT is a vector of the unit's outputs, one value per training
%   sample, drawn in blue in a new figure.
ptnum=numel(unitout);
figure
hold on
grid on
axis([0 ptnum 0 1])   % logsig outputs lie in (0,1)
plot(1:ptnum,unitout,'b-')
页: [1]
查看完整版本: 隐节点合成算法matlab程序