import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;

import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;


/**
 * Loads a label file and an image file in the big-endian IDX layout read by
 * {@link #parseYs(File)} and {@link #parseXs(File)}, optionally dumps them to a
 * tab-separated text file, and shows a small Swing viewer for stepping through
 * the images.
 */
public class MNISTParser {
	
	/** Edge length used for the dump header when no image is loaded to measure. */
	private static final int DEFAULT_IMAGE_SIZE = 28;
	
	//data
	/** One label per element, each in the range 0..255. */
	private int yVals_[] = null;
	/** Pixel values indexed as [element][column (x)][row (y)], each in 0..255. */
	private int xVals_[][][] = null;
	
	
	//visualisation widgets
	private JFrame myFrame_ = null;
	private DigitPanel mainPanel_ = null;
	private JButton nextButton_ = null;
	private JButton prevButton_ = null;
	private JLabel currentYVal_ = null;
	private int currentImage_ = 0;
	
	
	/**
	 * Prompts for the input/output files, then opens the viewer.
	 *
	 * @param args ignored
	 */
	public static void main(String args[]) {
		final MNISTParser m = new MNISTParser();
		m.pickFiles();
		// Swing components must be created and shown on the event dispatch thread.
		SwingUtilities.invokeLater(new Runnable() {
			public void run() {
				m.showVis();
			}
		});
	}
	
	/**
	 * Asks the user for the label file, the image file and (optionally) a
	 * target file for the text dump. Cancelling either open dialog, or any
	 * I/O failure, terminates the JVM with exit status 1. Cancelling the save
	 * dialog simply skips the dump.
	 */
	public void pickFiles() {
		JFileChooser jfc = new JFileChooser();
		
		File labelFile = chooseInputOrExit(jfc, "Open Y's");
		try {
			parseYs(labelFile);
		} catch (IOException e) {
			e.printStackTrace();
			System.exit(1);
		}
		
		File imageFile = chooseInputOrExit(jfc, "Open X's");
		try {
			parseXs(imageFile);
		} catch (IOException e) {
			e.printStackTrace();
			System.exit(1);
		}
		
		if(yVals_.length != xVals_.length) {
			// Not fatal: everything downstream only uses the common prefix.
			System.err.println("label/image count mismatch:\t" + yVals_.length + " vs " + xVals_.length);
		}
		
		jfc.setApproveButtonText("Target File");
		if(jfc.showSaveDialog(null) == JFileChooser.APPROVE_OPTION) {
			try {
				dumpToFile(jfc.getSelectedFile());
			} catch (IOException e) {
				e.printStackTrace();
				System.exit(1);
			}
		}
	}
	
	/**
	 * Shows an open dialog and returns the chosen file, echoing its path to
	 * standard output. Exits the JVM with status 1 if the dialog is not approved.
	 *
	 * @param jfc         chooser to reuse (keeps the last visited directory)
	 * @param approveText text for the approve button
	 * @return the selected file
	 */
	private File chooseInputOrExit(JFileChooser jfc, String approveText) {
		jfc.setApproveButtonText(approveText);
		if(jfc.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
			System.exit(1);
		}
		File file = jfc.getSelectedFile();
		System.out.println(file.toString());
		return file;
	}
	
	/**
	 * Builds and shows the viewer window: the digit canvas in the centre, the
	 * label text on top and previous/next buttons at the bottom.
	 */
	public void showVis() {
		myFrame_= new JFrame("Digit Vis");
		myFrame_.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
		mainPanel_ = new DigitPanel();
		nextButton_ = new JButton(">>");
		nextButton_.addActionListener(mainPanel_);
		prevButton_ = new JButton("<<");
		prevButton_.addActionListener(mainPanel_);
		currentYVal_ = new JLabel("Label = ?");
		updateLabel();
		
		myFrame_.setLayout(new BorderLayout());
		myFrame_.add(mainPanel_, BorderLayout.CENTER);
		
		JPanel lowerPanel = new JPanel(new GridLayout(1,2));
		lowerPanel.add(prevButton_);
		lowerPanel.add(nextButton_);
		
		myFrame_.add(lowerPanel, BorderLayout.SOUTH);
		myFrame_.add(currentYVal_, BorderLayout.NORTH);
		
		myFrame_.setSize(640,480);
		myFrame_.setResizable(true);
		myFrame_.setVisible(true);
	}
	
	/**
	 * Number of elements that have both a label and an image.
	 *
	 * @return the smaller of the two loaded counts, or 0 if either is not loaded
	 */
	private int elementCount() {
		if(yVals_ == null || xVals_ == null) {
			return 0;
		}
		return Math.min(yVals_.length, xVals_.length);
	}
	
	/**
	 * Refreshes the label text for the current element. Leaves the text alone
	 * when the label widget does not exist yet or no element is available.
	 */
	private void updateLabel() {
		if(currentYVal_ == null || currentImage_ >= elementCount()) {
			return;
		}
		currentYVal_.setText("Label = " + yVals_[currentImage_] + " for element #" + currentImage_);
	}
	
	/**
	 * Reads the image file: four big-endian 32-bit header fields (magic number,
	 * element count, rows, columns) followed by one unsigned byte per pixel,
	 * stored row by row for each image. The result is stored in
	 * {@code xVals_} as [element][column][row].
	 *
	 * @param file image file to read
	 * @throws IOException if the file cannot be read, is shorter than its
	 *                     header claims, or declares invalid dimensions
	 */
	private void parseXs(File file) throws IOException {
		DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
		try {
			//read header
			int magicNumber = in.readInt();
			System.out.println("magic number:\t" + magicNumber);
			
			int numElements = in.readInt();
			System.out.println("elements:\t" + numElements);
			
			int numRows = in.readInt();
			System.out.println("rows:\t" + numRows);
			
			int numColumns = in.readInt();
			System.out.println("cols:\t" + numColumns);
			
			if(numElements < 0 || numRows < 0 || numColumns < 0
					|| (long)numRows * (long)numColumns > Integer.MAX_VALUE) {
				throw new IOException("invalid image header: elements=" + numElements
						+ " rows=" + numRows + " cols=" + numColumns);
			}
			
			int images[][][] = new int[numElements][numColumns][numRows];
			
			//pull down a single image at a time from the stream
			byte pixels[] = new byte[numRows * numColumns];
			for(int n = 0; n < numElements; n++) {
				// readFully either fills the buffer or throws; a plain read() may return short.
				in.readFully(pixels);
				
				for(int row = 0; row < numRows; row++) {
					int rowStart = row * numColumns;
					for(int col = 0; col < numColumns; col++) {
						images[n][col][row] = byteToInt(pixels[rowStart + col]);
					}
				}
			}
			
			xVals_ = images;
		} finally {
			in.close();
		}
	}
	
	/**
	 * Reads the label file: two big-endian 32-bit header fields (magic number,
	 * element count) followed by one unsigned byte per label. Labels above 9
	 * are reported on standard error but still stored.
	 *
	 * @param file label file to read
	 * @throws IOException if the file cannot be read, is shorter than its
	 *                     header claims, or declares a negative element count
	 */
	private void parseYs(File file) throws IOException {
		DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
		try {
			//read header
			int magicNumber = in.readInt();
			System.out.println("magic number:\t" + magicNumber);
			
			int numElements = in.readInt();
			System.out.println("elements:\t" + numElements);
			
			if(numElements < 0) {
				throw new IOException("invalid label header: elements=" + numElements);
			}
			
			//read the rest in
			byte rawLabels[] = new byte[numElements];
			in.readFully(rawLabels);
			
			//convert them
			int labels[] = new int[numElements];
			for(int n = 0; n < numElements; n++) {
				labels[n] = byteToInt(rawLabels[n]);
				
				if(labels[n] > 9) {
					System.err.println("WTF?@" + n + ":\t" + labels[n]);
				}
			}
			
			yVals_ = labels;
		} finally {
			in.close();
		}
	}
	
	/**
	 * Writes the loaded data as tab-separated text: a header line
	 * ({@code lbl} followed by one {@code x,y} column name per pixel, 1-based)
	 * and then one line per element with its label and pixel values in the
	 * same column-major order.
	 *
	 * @param file target file; overwritten if it exists
	 * @throws IOException if the file cannot be written
	 */
	private void dumpToFile(File file) throws IOException {
		int count = elementCount();
		
		// Size the header from the data; fall back to the default when there is nothing to measure.
		int numColumns = DEFAULT_IMAGE_SIZE;
		int numRows = DEFAULT_IMAGE_SIZE;
		if(count > 0) {
			numColumns = xVals_[0].length;
			numRows = numColumns > 0 ? xVals_[0][0].length : 0;
		}
		
		BufferedWriter bw = new BufferedWriter(new FileWriter(file));
		try {
			// Write pieces straight to the buffer instead of growing a String per cell.
			bw.write("lbl");
			for(int x = 1; x <= numColumns; x++) {
				for(int y = 1; y <= numRows; y++) {
					bw.write("\t" + x + "," + y);
				}
			}
			bw.write("\n");
			
			for(int n = 0; n < count; n++) {
				bw.write(Integer.toString(yVals_[n]));
				
				int image[][] = xVals_[n];
				for(int i = 0; i < image.length; i++) {
					for(int j = 0; j < image[i].length; j++) {
						bw.write('\t');
						bw.write(Integer.toString(image[i][j]));
					}
				}
				
				bw.write("\n");
			}
			
			bw.flush();
		} finally {
			bw.close();
		}
	}
	
	
	/**
	 * Java has no unsigned byte, so widen to an int in the range 0..255.
	 * 
	 * @param b byte to reinterpret as unsigned
	 * @return the value of {@code b} treated as unsigned
	 */
	private static int byteToInt(byte b) {
		return b & 0xFF;
	}
	
	/**
	 * Canvas that draws the current image as a grid of grey squares and
	 * handles the previous/next buttons.
	 */
	private class DigitPanel extends JPanel implements ActionListener {
		private int pixelSize_ = 10;
		
		/**
		 * Paints the current image, scaled so the whole grid fits the panel.
		 * Pixel values are inverted so that 0 is drawn white and 255 black.
		 */
		public void paintComponent(Graphics g)   {
			// Paint background
			super.paintComponent(g);
			
			if(currentImage_ >= elementCount()) {
				return; // nothing loaded to draw
			}
			
			int image[][] = xVals_[currentImage_];
			if(image.length == 0 || image[0].length == 0) {
				return;
			}
			
			pixelSize_ = Math.min(getWidth()/image.length, getHeight()/image[0].length);
			
			Graphics2D g2d = (Graphics2D)g;
			
			for(int x = 0; x < image.length; x++) {
				for(int y = 0; y < image[x].length; y++) {
					setPixel(g2d, x, y, 255 - image[x][y]);
				}
			}
		}
		
		/** Fills one grid cell at (x, y) with the given grey tone (0..255). */
		private void setPixel(Graphics2D g2d, int x, int y, int tone) {
			g2d.setColor(new Color(tone, tone, tone));
			g2d.fillRect(x*pixelSize_, y*pixelSize_, pixelSize_, pixelSize_);
		}
		
		/**
		 * Steps to the next or previous element (clamped to the loaded range),
		 * then refreshes the label and repaints.
		 */
		public void actionPerformed(ActionEvent e) {
			if(e.getSource() == nextButton_ && currentImage_ < (elementCount() - 1)) {
				currentImage_++;
			} else if(e.getSource() == prevButton_ && currentImage_ > 0) {
				currentImage_--;
			}
			
			updateLabel();
			this.repaint();
		}
		
	}	
}
