
// g++ sdes16v2.cc -o sdes16v2
 
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdlib.h>
 
/* Unsigned integer aliases used by the bit-string conversion helpers.
 * typedef (not #define) so they behave as real types for the compiler. */
typedef unsigned char BYTE;             /* one byte */
typedef unsigned short int DoubleBYTE;  /* at least 16 bits: one cipher block */
 
/* The 48-bit key as an ASCII bit string ('0'/'1', most significant bit
 * first), NUL-terminated by SetKey. */
char key[49];
/* Round subkeys: each points at a 16-character window of key. The windows
 * are not individually NUL-terminated; consumers read exactly 16 chars. */
char *K1 = &key[0];
char *K2 = &key[16];
char *K3 = &key[32];

/*
 * SetKey - expand a key value into the global key string and select the
 * round subkey order.
 *
 * SetMe: key value; only its low 48 bits are used (bit 47 -> key[0],
 *        bit 0 -> key[47]).
 * Mode:  2 selects decryption, which applies the subkeys in reverse order
 *        (K1 and K3 exchanged); any other value selects encryption order.
 *
 * The subkey pointers are assigned from Mode directly instead of being
 * swapped relative to their previous values, so calling SetKey more than
 * once always yields the order requested by the latest call.
 */
void SetKey(unsigned long int SetMe, int Mode)
{
    /* unsigned long long is at least 64 bits, so shifting by up to 47 is
     * well defined even where unsigned long is only 32 bits. */
    unsigned long long Bits = SetMe;
    int i;

    for (i = 0; i < 48; i++)
        key[i] = ((Bits >> (47 - i)) & 1ULL) ? '1' : '0';
    key[48] = '\0';

    if (Mode == 2)
    {
        printf("Swapping keys for decrypting mode...\n");
        K1 = &key[32];
        K3 = &key[0];
    }
    else
    {
        K1 = &key[0];
        K3 = &key[32];
    }
    K2 = &key[16];
}
 
/*
 * S-box 0 (used by S0): 4x4 table of 2-bit outputs, indexed [row][column].
 * S0 forms the row from input bits 0 and 3, the column from bits 1 and 2.
 */
int SBox0[4][4] = {
    /* row 0 */ { 1, 0, 3, 2 },
    /* row 1 */ { 3, 2, 1, 0 },
    /* row 2 */ { 0, 2, 1, 3 },
    /* row 3 */ { 3, 1, 3, 2 },
};
 
/*
 * S-box 1 (used by S1): 4x4 table of 2-bit outputs, indexed [row][column].
 * S1 forms the row from input bits 0 and 3, the column from bits 1 and 2.
 */
int SBox1[4][4] = {
    /* row 0 */ { 1, 0, 1, 2 },
    /* row 1 */ { 3, 2, 1, 0 },
    /* row 2 */ { 3, 2, 1, 0 },
    /* row 3 */ { 2, 1, 3, 2 },
};

/*
 * S-box 2 (used by S2): 4x4 table of 2-bit outputs, indexed [row][column].
 * S2 forms the row from input bits 0 and 3, the column from bits 1 and 2.
 */
int SBox2[4][4] = {
    /* row 0 */ { 1, 3, 3, 2 },
    /* row 1 */ { 0, 2, 2, 1 },
    /* row 2 */ { 1, 1, 1, 3 },
    /* row 3 */ { 2, 1, 0, 2 },
};

/*
 * S-box 3 (used by S3): 4x4 table of 2-bit outputs, indexed [row][column].
 * S3 forms the row from input bits 0 and 3, the column from bits 1 and 2.
 */
int SBox3[4][4] = {
    /* row 0 */ { 1, 2, 0, 1 },
    /* row 1 */ { 2, 0, 1, 0 },
    /* row 2 */ { 2, 3, 1, 0 },
    /* row 3 */ { 2, 3, 1, 3 },
};

 
/*
 * Disabled alternative S-box tables, compiled out by "#if 0".
 * They reuse the names SBox1..SBox3, which are already defined above, so
 * simply enabling this block would be a redefinition error; to use them,
 * replace the active definitions instead.
 */
#if 0
int SBox1[4][4] =
{
{0, 1, 2, 3},
{2, 0, 1, 3},
{3, 0, 1, 0},
{2, 1, 0, 3}
};

int SBox2[4][4] =
{
{2, 1, 2, 3},
{1, 2, 3, 1},
{0, 0, 3, 0},
{3, 1, 1, 0}
};

int SBox3[4][4] =
{
{1, 2, 0, 1},
{2, 2, 1, 3},
{0, 3, 1, 0},
{2, 3, 0, 3}
};
#endif
 
BYTE BinaryToByte(char *Binary)
{
            int n, i;
            BYTE ReturnMe = 0;
 
            for (i = 0, n = 128; i <= 7; i++, n /= 2)
            {
                        switch(Binary[i])
                        {
                        case '0': break;
                        case '1':
                                    ReturnMe += n;
                        break;
                        default:
                                    puts("BinaryToByte failed");
                                    exit(0);
                        break;
                        }
            }
 
            return ReturnMe;
}
 
/*
 * ByteToBinary - render TheByte as 8 ASCII bit characters, most significant
 * bit first, followed by a NUL terminator.
 *
 * pBinary: destination with room for at least 9 chars.
 */
void ByteToBinary(char *pBinary, BYTE TheByte)
{
    int Shift;

    for (Shift = 7; Shift >= 0; Shift--)
        *pBinary++ = ((TheByte >> Shift) & 1) ? '1' : '0';
    *pBinary = '\0';
}

DoubleBYTE BinaryToDoubleByte(char *Binary)
{
            int n, i;
            DoubleBYTE ReturnMe = 0;
 
            for (i = 0, n = 32768; i <= 15; i++, n /= 2)
            {
                        switch(Binary[i])
                        {
                        case '0': break;
                        case '1':
                                    ReturnMe += n;
                        break;
                        default:
                                    puts("BinaryToDoubleByte failed");
                                    exit(0);
                        break;
                        }
            }
 
            return ReturnMe;
}
 
/*
 * DoubleByteToBinary - render TheDoubleByte as 16 ASCII bit characters,
 * most significant bit first, followed by a NUL terminator.
 *
 * pBinary: destination with room for at least 17 chars.
 */
void DoubleByteToBinary(char *pBinary, DoubleBYTE TheDoubleByte)
{
    int Shift;

    for (Shift = 15; Shift >= 0; Shift--)
        *pBinary++ = ((TheDoubleByte >> Shift) & 1) ? '1' : '0';
    *pBinary = '\0';
}

 
void S0(char *pDest, char *pSrc)
{
            int Row = (pSrc[0] - '0') * 2 + (pSrc[3] - '0');
            int Col = (pSrc[1] - '0') * 2 + (pSrc[2] - '0');
            BYTE LookUp = SBox0[Row][Col];
           
            pDest[0] = '0' + ((LookUp & 0x02) >> 1);
            pDest[1] = '0' + (LookUp & 0x01);
            pDest[2] = '\0';
}
 
void S1(char *pDest, char *pSrc)
{
            int Row = (pSrc[0] - '0') * 2 + (pSrc[3] - '0');
            int Col = (pSrc[1] - '0') * 2 + (pSrc[2] - '0');
            BYTE LookUp = SBox1[Row][Col];
           
            pDest[0] = '0' + ((LookUp & 0x02) >> 1);
            pDest[1] = '0' + (LookUp & 0x01);
            pDest[2] = '\0';
}

void S2(char *pDest, char *pSrc)
{
            int Row = (pSrc[0] - '0') * 2 + (pSrc[3] - '0');
            int Col = (pSrc[1] - '0') * 2 + (pSrc[2] - '0');
            BYTE LookUp = SBox2[Row][Col];
           
            pDest[0] = '0' + ((LookUp & 0x02) >> 1);
            pDest[1] = '0' + (LookUp & 0x01);
            pDest[2] = '\0';
}

void S3(char *pDest, char *pSrc)
{
            int Row = (pSrc[0] - '0') * 2 + (pSrc[3] - '0');
            int Col = (pSrc[1] - '0') * 2 + (pSrc[2] - '0');
            BYTE LookUp = SBox3[Row][Col];
           
            pDest[0] = '0' + ((LookUp & 0x02) >> 1);
            pDest[1] = '0' + (LookUp & 0x01);
            pDest[2] = '\0';
}
 
/*
 * P8 - permute the four 2-bit S-box outputs into one 8-character bit string.
 *
 * Each S-box output bit appears exactly once in pDest; pDest needs room for
 * 9 chars (8 bits plus NUL).
 */
void P8(char *pDest, char *pS0Output, char *pS1Output, char *pS2Output, char *pS3Output)
{
    /* Pick[i] = { which S-box output, which of its two chars } for pDest[i]. */
    static const int Pick[8][2] = {
        {2, 0}, {1, 0}, {3, 1}, {1, 1},
        {3, 0}, {0, 1}, {2, 1}, {0, 0}
    };
    char *Boxes[4];
    int i;

    Boxes[0] = pS0Output;
    Boxes[1] = pS1Output;
    Boxes[2] = pS2Output;
    Boxes[3] = pS3Output;

    for (i = 0; i < 8; i++)
        pDest[i] = Boxes[Pick[i][0]][Pick[i][1]];
    pDest[8] = '\0';
}

void XOR8(char *pDest, char *Operand1, char *Operand2)
{
            int i;
            BYTE a, b;
 
            for (i = 0; i <= 7; i++)
            {
                        a = Operand1[i] - '0';
                        b = Operand2[i] - '0';
                        pDest[i] = '0' + (a ^ b);
            }
            pDest[8] = '\0';
}
 
void XOR16(char *pDest, char *Operand1, char *Operand2)
{
            int i;
            BYTE a, b;
 
            for (i = 0; i <= 15; i++)
            {
                        a = Operand1[i] - '0';
                        b = Operand2[i] - '0';
                        pDest[i] = '0' + (a ^ b);
            }
            pDest[16] = '\0';
}
 
void EP(char *pDest, char *pSrc)
{
            pDest[0] = pSrc[4];
            pDest[1] = pSrc[3];
            pDest[2] = pSrc[2];
            pDest[3] = pSrc[1];
            pDest[4] = pSrc[5];
            pDest[5] = pSrc[6];
            pDest[6] = pSrc[7];
            pDest[7] = pSrc[4];
            pDest[8] = pSrc[0];
            pDest[9] = pSrc[3];
            pDest[10] = pSrc[5];
            pDest[11] = pSrc[6];
            pDest[12] = pSrc[0];
            pDest[13] = pSrc[2];
            pDest[14] = pSrc[1];
            pDest[15] = pSrc[7];
	    
	    pDest[16] = '\0';
}
 
/*
 * f - one round of the cipher on a 16-character ASCII bit string.
 *
 * pSrc:   16 bit characters; chars 0..7 are the left half, 8..15 the right.
 * Subkey: 16 bit characters (need not be NUL-terminated).
 * pDest:  receives (left XOR F(right, Subkey)) followed by the unchanged
 *         right half, plus a NUL (room for 17 chars). pDest must not overlap
 *         pSrc, because the right half is copied with memcpy.
 *
 * F(right, Subkey) = P8(S0..S3(EP(right) XOR Subkey)).
 */
void f(char *pDest, char *pSrc, char *Subkey)
{
    char *pLeft = pSrc;
    char *pRight = pSrc + 8;
    char Expanded[17];   /* EP(right) */
    char Mixed[17];      /* EP(right) XOR Subkey */
    char Sbox[4][3];     /* 2-bit output of each S-box */
    char Permuted[9];    /* P8 of the S-box outputs */
    char NewLeft[9];     /* left XOR Permuted */

    /* Expand the right half to 16 bits and mix in the round subkey. */
    EP(Expanded, pRight);
    DoubleByteToBinary(Mixed, BinaryToDoubleByte(Expanded) ^ BinaryToDoubleByte(Subkey));

    /* Each 4-bit group of the mixed value feeds its own S-box. */
    S0(Sbox[0], Mixed + 0);
    S1(Sbox[1], Mixed + 4);
    S2(Sbox[2], Mixed + 8);
    S3(Sbox[3], Mixed + 12);

    P8(Permuted, Sbox[0], Sbox[1], Sbox[2], Sbox[3]);
    XOR8(NewLeft, Permuted, pLeft);

    memcpy(pDest, NewLeft, 8);
    memcpy(pDest + 8, pRight, 8);
    pDest[16] = '\0';
}
 
/*
 * SW - exchange the two 8-character halves of a 16-character bit string.
 *
 * pDest receives pSrc[8..15] followed by pSrc[0..7], plus a NUL
 * (room for 17 chars).
 */
void SW(char *pDest, char *pSrc)
{
    int k;

    /* New left half <- old right half. */
    for (k = 0; k < 8; k++)
        pDest[k] = pSrc[k + 8];

    /* New right half <- old left half. */
    for (k = 0; k < 8; k++)
        pDest[k + 8] = pSrc[k];

    pDest[16] = '\0';
}
 
/*
 * DoubleSDES - run one 16-bit block through three rounds:
 * f(K1), swap, f(K2), swap, f(K3). There is no swap after the last round.
 *
 * The subkey order comes from the globals K1..K3, which SetKey arranges for
 * encryption or decryption. Returns the resulting 16-bit block.
 */
DoubleBYTE DoubleSDES(DoubleBYTE nPlaintext)
{
    /* Two buffers used alternately; f must not write into its own input. */
    char BufA[17];
    char BufB[17];

    DoubleByteToBinary(BufA, nPlaintext);

    f(BufB, BufA, K1);   /* round 1 */
    SW(BufA, BufB);
    f(BufB, BufA, K2);   /* round 2 */
    SW(BufA, BufB);
    f(BufB, BufA, K3);   /* round 3 */

    return BinaryToDoubleByte(BufB);
}
 
/* Print the command-line synopsis to stderr and exit with a failure status. */
static void UsageAndExit(const char *ProgName)
{
    fprintf(stderr, "Usage: %s [e|d] infilename outfilename\n", ProgName);
    exit(EXIT_FAILURE);
}

/*
 * main - encrypt ("e") or decrypt ("d") infilename into outfilename.
 *
 * The key is read from stdin as a hexadecimal number; SetKey uses its low
 * 48 bits. The input is processed in 16-bit blocks, with the first byte of
 * each pair forming the high half. If the input length is odd, the last block
 * is padded with a zero low byte. The output length is therefore always even,
 * and round-tripping an odd-length file leaves a trailing zero byte.
 *
 * Returns 0 on success. For bad arguments, an unopenable file, an unparsable
 * key, or a read/write error, it prints a diagnostic to stderr and returns a
 * failure status.
 */
int main(int argc, char **argv)
{
    const char *ProgName = (argc > 0 && argv[0] != NULL) ? argv[0] : "sdes16v2";
    int EvenOdd = 0;               /* 1 while a block's high byte is pending */
    int Input;                     /* int, not unsigned: fgetc returns int and EOF is negative */
    unsigned int Plaintext = 0, Ciphertext;
    unsigned long int TheKey;      /* assumes a 64-bit unsigned long (must hold 48 bits) */
    int mode = 0;
    int status = EXIT_SUCCESS;
    int WriteError;
    FILE *fIn, *fOut;

    if (argc < 4)
        UsageAndExit(ProgName);
    if (!strcmp(argv[1], "e"))
        mode = 1;
    else if (!strcmp(argv[1], "d"))
        mode = 2;
    else
        UsageAndExit(ProgName);

    /* Binary mode: ciphertext bytes are arbitrary, and text-mode newline
     * translation on some platforms (e.g. Windows) would corrupt them. */
    if ((fIn = fopen(argv[2], "rb")) == NULL)
    {
        perror(argv[2]);
        return EXIT_FAILURE;
    }
    if ((fOut = fopen(argv[3], "wb")) == NULL)
    {
        perror(argv[3]);
        fclose(fIn);
        return EXIT_FAILURE;
    }

    fprintf(stderr, "You must enter a hexadecimal key, such as 76A5FC22B8FA\n");
    fprintf(stderr, "The last four hex digits (B8FA) are the 16-bit subkey for the last round.\n");
    fprintf(stderr, "Please enter the key you'd like to use:\n");
    /* Without this check a failed conversion would leave TheKey uninitialized. */
    if (scanf("%lx", &TheKey) != 1)
    {
        fprintf(stderr, "Could not read a hexadecimal key.\n");
        fclose(fIn);
        fclose(fOut);
        return EXIT_FAILURE;
    }
    SetKey(TheKey, mode);

    /* Loop once more after EOF while a high byte is pending, so a final odd
     * byte is still processed (with a zero low byte). */
    while ((Input = fgetc(fIn)) != EOF || EvenOdd != 0)
    {
        if (!EvenOdd)
        {
            EvenOdd = 1;
            Plaintext = (unsigned int)Input << 8;
        }
        else
        {
            EvenOdd = 0;
            if (Input != EOF)
                Plaintext |= (unsigned int)Input;
            Ciphertext = DoubleSDES(Plaintext);
            fputc((Ciphertext >> 8) & 0xFF, fOut);
            fputc(Ciphertext & 0xFF, fOut);
        }
    }

    /* fgetc also returns EOF on a read error; distinguish it from end of file. */
    if (ferror(fIn))
    {
        perror(argv[2]);
        status = EXIT_FAILURE;
    }
    fclose(fIn);

    /* fclose flushes buffered output, so its result matters as well. */
    WriteError = ferror(fOut);
    if (fclose(fOut) != 0 || WriteError)
    {
        fprintf(stderr, "Error writing %s\n", argv[3]);
        status = EXIT_FAILURE;
    }
    printf("\n");

    return status;
}


