In [None]:
# imports
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
class NerualNetTenth:
    '''Single-layer sigmoid network trained to pick out one element of its input.

    Each input is a vector of ``nInX`` values drawn from ``num_set`` (shifted
    by -0.5 before being fed to the network); the target is the unshifted
    element at ``target_index``.  The model is ``g(alpha * A.x)`` where ``A``
    is a ``1 x nInX`` weight matrix fitted sample-by-sample with steepest
    descent on the squared error.

    Parameters
    ----------
    nInX : int
        Size of each input vector (must be >= 1).
    nTrain : int
        Number of unique training samples.
    nRepeat : int
        Number of times every unique training sample is presented.
    verbose : bool
        Whether to print the training data and the per-sample errors.
    target_index : int or None
        Index of the element the network should reproduce.  ``None`` selects
        the tenth element (index 9) when the input has at least ten entries
        and the last element otherwise.
    seed : int or None
        If given, the global ``numpy.random`` and ``random`` generators are
        seeded with it before any random draw, making the run reproducible.

    Attributes
    ----------
    alpha : sigmoid steepness, see ``set_alpha``.
    xT, zT : the unique training inputs (shifted) and targets.
    xT_rep, zT_rep : the shuffled training sequence (``nTrain*nRepeat`` entries).
    A : the trained weight matrix.
    resid : flat array of ``apply(x) - z`` over the unique training samples.
    '''

    def __init__(self, nInX=10, nTrain=10, nRepeat=100, verbose=True,
                 target_index=None, seed=None):

        if nInX < 1:
            raise ValueError('nInX must be at least 1')
        if target_index is None:
            # Keep the historical "tenth element" target whenever it exists.
            target_index = min(9, nInX - 1)
        if not 0 <= target_index < nInX:
            raise IndexError('target_index %d is out of range for nInX=%d'
                             % (target_index, nInX))

        if seed is not None:
            # Samples/initial weights come from numpy, the shuffle from `random`.
            np.random.seed(seed)
            random.seed(seed)

        self.num_set = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # Values to choose from for input

        self.nInX = nInX  # Size of the input array
        self.nTrain = nTrain  # Number of unique training data
        self.nRepeat = nRepeat  # Number of times to repeat training set
        self.nTrain_rep = nTrain * nRepeat  # Number of total training sets
        self.verbose = verbose  # Whether to print out info
        self.target_index = target_index  # Element of the input used as the target

        # Set the alpha value
        self.alpha = self.set_alpha()

        # Set up the training set
        self.xT, self.zT, self.xT_rep, self.zT_rep = self.get_training_set()

        # Train the model to get A
        self.A = self.train_neural_net()

        # Test on the unique training samples to get the residuals
        self.resid = np.array(
            [self.apply(x) - z for x, z in zip(self.xT, self.zT)]
        ).ravel()

    def set_alpha(self):
        '''Return alpha = 10 / (nInX * max|num_set|), the sigmoid steepness.'''
        return 10. / (self.nInX * np.amax(np.abs(np.array(self.num_set))))

    def get_training_set(self):
        '''Generate nTrain random input/target pairs and a shuffled sequence
        in which every pair appears exactly nRepeat times.

        Returns ``(xT, zT, xT_rep, zT_rep)``: the unique inputs/targets as
        lists and the shuffled repeated sequence as tuples.
        '''

        xT = []
        zT = []
        if self.verbose: print('Training data:')
        for _ in range(self.nTrain):
            x = np.random.choice(self.num_set, size=self.nInX)
            zT.append(x[self.target_index])  # Target is the selected element
            xT.append(list(x - 0.5))  # Subtract 0.5 to place in the nonlinear region of g

            if self.verbose:
                print('x: ', np.round(xT[-1], 1), 'z: ', np.round(zT[-1], 1))

        # Exactly nRepeat presentations of each pair, so the sequence length
        # matches nTrain_rep and no sample is dropped by the shuffle.
        pairs = list(zip(xT, zT)) * self.nRepeat
        random.shuffle(pairs)

        if pairs:
            xT_rep, zT_rep = zip(*pairs)
        else:
            # zip(*[]) yields nothing to unpack; return empty sequences instead.
            xT_rep, zT_rep = (), ()

        return xT, zT, xT_rep, zT_rep

    def train_neural_net(self, tol=1e-8, eta=0.5, max_iter=None):
        '''Train the network and return the weight matrix A.

        For every entry of the training sequence, steepest-descent steps of
        size ``eta`` are taken until the squared error on that entry is at
        most ``tol``.  ``max_iter`` optionally caps the number of steps per
        entry; the default ``None`` leaves the loop uncapped.
        '''
        # Randomly initialize A matrix, 1 x size of nInX
        A = np.random.rand(1, self.nInX)

        for k in range(self.nTrain_rep):

            xk = np.asarray(self.xT_rep[k])
            yk = self.zT_rep[k]
            zk = self.g(self.alpha, np.dot(A, xk))

            n_iter = 0
            while (zk - yk)**2 > tol and (max_iter is None or n_iter < max_iter):
                # Gradient of (g(alpha*A.x)-y)^2 with respect to A
                delA = 2*self.alpha*(zk - yk)*zk*(1. - zk)*xk
                A -= eta*delA
                zk = self.g(self.alpha, np.dot(A, xk))
                n_iter += 1

            if self.verbose:
                print('Training data: ', k, 'Error: ', (zk - yk)**2)

        return A

    def g(self, alpha, x):
        '''Sigmoid nonlinear function 1/(1+exp(-alpha*x)).'''
        return 1/(1 + np.exp(-alpha*x))

    def apply(self, x):
        '''Apply the trained neural net to an (already shifted) input x.'''
        return self.g(self.alpha, np.dot(self.A, np.array(x)))

    def new_data_set(self, nData):
        '''Test the model on nData freshly drawn inputs.

        Returns a list with one residual (prediction minus target, a
        one-element array) per generated input.
        '''
        resid = []
        for _ in range(nData):
            x = np.random.choice(self.num_set, size=self.nInX)
            resid.append(self.apply(x - 0.5) - x[self.target_index])

            if self.verbose:
                print('New data: ', x, 'Error: ', resid[-1])

        return resid
    

In [None]:
# Set up and train the neural network.
# Seed both random sources the class draws from (numpy for the samples and the
# initial weights, the stdlib `random` module for shuffling) so that
# Restart & Run All reproduces the same training run.
np.random.seed(0)
random.seed(0)
nn=NerualNetTenth()

Training data:  221 Error:  [9.91766097e-09]
Training data:  222 Error:  [9.16469548e-09]
Training data:  223 Error:  [9.91830894e-09]
Training data:  224 Error:  [9.58708095e-09]
Training data:  225 Error:  [9.9328349e-09]
Training data:  226 Error:  [9.66449772e-09]
Training data:  227 Error:  [9.94801805e-09]
Training data:  228 Error:  [9.90647996e-09]
Training data:  229 Error:  [9.7243247e-09]
Training data:  230 Error:  [9.88968274e-09]
Training data:  231 Error:  [9.98185599e-09]
Training data:  232 Error:  [9.92829211e-09]
Training data:  233 Error:  [9.77501352e-09]
Training data:  234 Error:  [9.78787301e-09]
Training data:  235 Error:  [9.84082066e-09]
Training data:  236 Error:  [9.60609042e-09]
Training data:  237 Error:  [4.5349069e-09]
Training data:  238 Error:  [9.96523561e-09]
Training data:  239 Error:  [9.1811904e-09]
Training data:  240 Error:  [9.61241455e-09]
Training data:  241 Error:  [9.38263908e-09]
Training data:  242 Error:  [9.38263908e-09]
Training data:

Training data:  461 Error:  [9.79571242e-09]
Training data:  462 Error:  [9.66966777e-09]
Training data:  463 Error:  [9.66966777e-09]
Training data:  464 Error:  [9.97013486e-09]
Training data:  465 Error:  [9.9432471e-09]
Training data:  466 Error:  [9.9634305e-09]
Training data:  467 Error:  [9.87665618e-09]
Training data:  468 Error:  [9.89540263e-09]
Training data:  469 Error:  [9.76771019e-09]
Training data:  470 Error:  [9.93875718e-09]
Training data:  471 Error:  [9.91924655e-09]
Training data:  472 Error:  [9.69403194e-09]
Training data:  473 Error:  [2.07603777e-10]
Training data:  474 Error:  [8.98866004e-09]
Training data:  475 Error:  [9.98400829e-09]
Training data:  476 Error:  [9.75363941e-09]
Training data:  477 Error:  [9.31525972e-09]
Training data:  478 Error:  [9.19521404e-09]
Training data:  479 Error:  [9.76057992e-09]
Training data:  480 Error:  [9.94577941e-09]
Training data:  481 Error:  [9.5217358e-09]
Training data:  482 Error:  [9.82385453e-09]
Training data

Training data:  677 Error:  [9.99285943e-09]
Training data:  678 Error:  [9.99285943e-09]
Training data:  679 Error:  [9.89998016e-09]
Training data:  680 Error:  [9.68683945e-09]
Training data:  681 Error:  [9.84419829e-09]
Training data:  682 Error:  [9.94576525e-09]
Training data:  683 Error:  [9.78604098e-09]
Training data:  684 Error:  [9.76537309e-09]
Training data:  685 Error:  [9.76537309e-09]
Training data:  686 Error:  [9.79378073e-09]
Training data:  687 Error:  [9.9675824e-09]
Training data:  688 Error:  [9.92134668e-09]
Training data:  689 Error:  [9.81148997e-09]
Training data:  690 Error:  [9.90691989e-09]
Training data:  691 Error:  [9.77470069e-09]
Training data:  692 Error:  [9.22309044e-09]
Training data:  693 Error:  [9.92628018e-09]
Training data:  694 Error:  [9.15605486e-09]
Training data:  695 Error:  [9.15605486e-09]
Training data:  696 Error:  [9.46200768e-09]
Training data:  697 Error:  [9.95639448e-09]
Training data:  698 Error:  [9.6375038e-09]
Training dat

Training data:  899 Error:  [9.75736638e-09]
Training data:  900 Error:  [9.66516013e-09]
Training data:  901 Error:  [9.76399684e-09]
Training data:  902 Error:  [9.74890866e-09]
Training data:  903 Error:  [9.85878983e-09]
Training data:  904 Error:  [9.82847366e-09]
Training data:  905 Error:  [9.82847366e-09]
Training data:  906 Error:  [9.85259886e-09]
Training data:  907 Error:  [9.91932162e-09]
Training data:  908 Error:  [9.92045306e-09]
Training data:  909 Error:  [9.62084972e-09]
Training data:  910 Error:  [9.80147078e-09]
Training data:  911 Error:  [9.80147078e-09]
Training data:  912 Error:  [9.72817179e-09]
Training data:  913 Error:  [9.8262928e-09]
Training data:  914 Error:  [9.64120264e-09]
Training data:  915 Error:  [9.53562328e-09]
Training data:  916 Error:  [9.68408615e-09]
Training data:  917 Error:  [9.70087096e-09]
Training data:  918 Error:  [9.65004418e-09]
Training data:  919 Error:  [9.94811545e-09]
Training data:  920 Error:  [9.94811545e-09]
Training da

In [None]:
# Check that the training was performed correctly: plot the residual of each
# unique training sample and save the figure alongside the notebook.
fig, ax = plt.subplots(figsize=(5, 5))

ax.plot(nn.resid, '+', markersize=10)
ax.set(xlabel='Trial data', ylabel='Residual')

fig.savefig('train_resid.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Print out the A matrix with two-digit precision.
# np.printoptions is a context manager, so the precision change is limited to
# this print and does not alter how arrays are shown by any other cell.
with np.printoptions(precision=2):
    print(nn.A)

In [None]:
# Evaluate the trained network on 100 freshly generated inputs and plot
# the residual (prediction minus target) of each one.
resid = nn.new_data_set(100)

fig, ax = plt.subplots(figsize=(5, 5))

ax.plot(resid, '+', markersize=10)
ax.set(xlabel='Trial data', ylabel='Residual')

fig.savefig('data_resid.pdf', bbox_inches='tight')
plt.show()
