Skip to main content

MATLAB cross validation

% Partition the rows of `matrix` into k folds using the built-in cvpartition
% (Statistics and Machine Learning Toolbox). `matrix` and `k` must already be
% defined in the workspace; each row of `matrix` is treated as one sample.
samplesize = size(matrix, 1);
c = cvpartition(samplesize, 'kfold', k); % fold membership for every sample

% Example of what `c` displays in the MATLAB console (N = 10, k = 4):
%   K-fold cross validation partition
%                N: 10
%      NumTestSets: 4
%        TrainSize: 8  7  7  8
%         TestSize: 2  3  3  2

for i = 1:c.NumTestSets
    % training(c,i) / test(c,i) return logical masks: true marks a sample that
    % belongs to the training / test set of fold i (1 = in set, 0 = not).
    trainIdxs = find(training(c, i));
    testIdxs  = find(test(c, i));

    % Select rows by position: the indices come from the partition, not from
    % the data values stored in `matrix`.
    trainMatrix = matrix(trainIdxs, :);
    testMatrix  = matrix(testIdxs, :);

    % Train on trainMatrix and evaluate on testMatrix here; both are
    % overwritten on the next iteration.
end

%% Calculate the performance of a partition
% Summarize per-fold confusion counts into sensitivity, specificity and
% accuracy, display the summary, and append the counts and scores as
% tab-delimited rows to the text file 'cv'.
%
% Inputs expected in the workspace (not defined in this snippet):
%   cvtp, cvfn, cvfp, cvtn - per-fold counts of true positives, false
%                            negatives, false positives and true negatives
%                            (one entry per fold)
%   kfold                  - number of folds
%   P, N                   - presumably the numbers of positive and negative
%                            samples; kfold == P+N is treated as
%                            leave-one-out -- TODO confirm against caller
% Outputs:
%   sensitivity, specificity, accuracy - per-fold scores (k-fold) or pooled
%                                        scalars (leave-one-out)
%   sen, spe, acc                      - the reported summary values
%   selectedKfoldSen, selectedKfoldSpe - folds that contributed to the
%                                        sensitivity / specificity averages

selectedKfoldSen = [];
selectedKfoldSpe = [];

if kfold == (P + N) % leave-one-out
    % With one sample per fold, per-fold sensitivity or specificity is
    % undefined for most folds, so pool the counts over all folds instead.
    sensitivity = sum(cvtp) / (sum(cvtp) + sum(cvfn));
    specificity = sum(cvtn) / (sum(cvfp) + sum(cvtn));
    accuracy    = (sum(cvtp) + sum(cvtn)) / ...
                  (sum(cvtp) + sum(cvfn) + sum(cvfp) + sum(cvtn));

    % Report the pooled scores under the same names as the k-fold branch.
    sen = sensitivity;
    spe = specificity;
    acc = accuracy;
else
    % Use row vectors of the first kfold entries, so the per-fold results
    % (and the rows written to 'cv' below) are rows whatever the input
    % orientation.
    tp = reshape(cvtp(1:kfold), 1, []);
    fn = reshape(cvfn(1:kfold), 1, []);
    fp = reshape(cvfp(1:kfold), 1, []);
    tn = reshape(cvtn(1:kfold), 1, []);

    nPos = tp + fn; % positive test samples in each fold
    nNeg = fp + tn; % negative test samples in each fold

    % A fold with no positive (negative) test sample has an undefined
    % sensitivity (specificity). Skip it rather than counting it as 1.
    selectedKfoldSen = find(nPos > 0);
    selectedKfoldSpe = find(nNeg > 0);

    sensitivity = tp(selectedKfoldSen) ./ nPos(selectedKfoldSen);
    specificity = tn(selectedKfoldSpe) ./ nNeg(selectedKfoldSpe);
    accuracy    = (tp + tn) ./ (nPos + nNeg);

    sen = mean(sensitivity);
    spe = mean(specificity);
    acc = mean(accuracy);
end

% Display the summary (deliberately not semicolon-terminated).
sen
spe
acc

% Append one tab-delimited row per quantity to the file 'cv'.
dlmwrite('cv', cvtp', 'delimiter', '\t', '-append');
dlmwrite('cv', cvfn', 'delimiter', '\t', '-append');
dlmwrite('cv', cvfp', 'delimiter', '\t', '-append');
dlmwrite('cv', cvtn', 'delimiter', '\t', '-append');

dlmwrite('cv', selectedKfoldSen, 'delimiter', '\t', '-append');
dlmwrite('cv', selectedKfoldSpe, 'delimiter', '\t', '-append');

dlmwrite('cv', sensitivity, 'delimiter', '\t', '-append');
dlmwrite('cv', specificity, 'delimiter', '\t', '-append');
dlmwrite('cv', accuracy,    'delimiter', '\t', '-append');

Comments

  1. Is this code for k-fold cross validation? If I want to apply it to a neural network, specifically an MLP, which part of the code should I add it to?

    ReplyDelete
  2. Hello, in case my data set consists of n-dimensional column vectors, what command should I use to partition my data? I tried "crossvalind" to generate indices for each column vector, but I get an error.

    ReplyDelete
  3. What are the definitions of "cvtp" and "cvfn"? I cannot follow the code once I reach those two parameters.

    ReplyDelete
  4. Undefined variables: cvfn, cvtp, cvfp and cvtn!

    ReplyDelete
  5. Your code has a mistake; here is the corrected code:

    % Split the rows of `matrix` into k folds; every row is one sample.
    samplesize = size(matrix, 1);
    c = cvpartition(samplesize, 'kfold', k); % fold assignment of every row

    % Console display of `c` for a 10-sample, 4-fold example:
    %   K-fold cross validation partition
    %   N: 10
    %   NumTestSets: 4
    %   TrainSize: 8 7 7 8
    %   TestSize: 2 3 3 2

    for i = 1:k
        % training/test give logical masks over the samples (1 = member);
        % find turns each mask into row positions.
        trainIdxs = find(training(c, i));
        testIdxs  = find(test(c, i));

        % Rows belonging to the training and test split of fold i.
        trainMatrix = matrix(trainIdxs, :);
        testMatrix  = matrix(testIdxs, :);
    end

    ReplyDelete

Post a Comment

Popular posts from this blog

R tutorial

Install R in linux ============ In CRAN home page, the latest version is not available. So, in fedora, Open the terminal yum list R  --> To check the latest available version of r yum install R --> install R version yum update R --> update current version to latest one 0 find help ============ ?exact topic name (  i.e.   ?mean ) 0.0 INSTALL 3rd party package  ==================== install.packages('mvtnorm' , dependencies = TRUE , lib='/home/alamt/myRlibrary/')   #  install new package BED file parsing (Always use read.delim it is the best) library(MASS) #library(ggplot2) dirRoot="D:/research/F5shortRNA/TestRIKEN/Rscripts/" dirData="D:/research/F5shortRNA/TestRIKEN/" setwd(dirRoot) getwd() myBed="test.bed" fnmBed=paste(dirData, myBed, sep="") # ccdsHh19.bed   tmp.bed ## Read bed use read.delim - it is the  best mybed=read.delim(fnmBed, header = FALSE, sep = "\t", quote = &q