P2P通信の基礎部分

CREA2014(CREACOIN)のプログラムに組み込まれる予定のP2P通信の基礎部分のコードを書き上げました。旧CREAのコードを改良したものです。これ以上の大きな変更はないと思います。後日GitHubのプログラムも更新します。

主な仕様ですが、通信の暗号化には、鍵交換にRSA2048bit、データ本体の暗号化にAES256bitを使用できるようにしました。旧CREAでは、RSA1024bitを使用していましたが、既に脆弱とされているので、RSA2048bitに対応しました。一応RSA1024bitを用いた通信もできるようになっています。

それから、CREA2014で使うかどうかは分かりませんが、圧縮機能も付けました。テキストデータの圧縮はある程度意味があるでしょうが、既に圧縮されている動画データを更に圧縮しようとしても意味がないので、通信内容によって使い分ける必要があるでしょう。

 

    public class CommunicationApparatus
    {
        private readonly NetworkStream ns;
        private readonly RijndaelManaged rm;

        public CommunicationApparatus(NetworkStream _ns, RijndaelManaged _rm)
        {
            ns = _ns;
            rm = _rm;
        }

        public CommunicationApparatus(NetworkStream _ns) : this(_ns, null) { }

        public byte ReadBytes()
        {
            return ReadBytesInnner(false);
        }

        public void WriteBytes(byte data)
        {
            WriteBytesInner(data, false);
        }

        public byte ReadCompressedBytes()
        {
            return ReadBytesInnner(true);
        }

        public void WriteCompreddedBytes(byte data)
        {
            WriteBytesInner(data, true);
        }

        private byte ReadBytesInnner(bool isCompressed)
        {
            //最初の4バイトは本来のデータの長さ
            byte
dataLengthBytes = new byte[4];
            ns.Read(dataLengthBytes, 0, 4);
            int dataLength = BitConverter.ToInt32(dataLengthBytes, 0);

            if (dataLength == 0)
                return new byte { };

            //次の32バイトは受信データのハッシュ(破損検査用)
            byte hash = new byte[32];
            ns.Read(hash, 0, 32);

            //次の4バイトは受信データの長さ
            byte readDataLengthBytes = new byte[4];
            ns.Read(readDataLengthBytes, 0, 4);
            int readDataLength = BitConverter.ToInt32(readDataLengthBytes, 0);

            byte readData = null;

            using (MemoryStream ms = new MemoryStream())
            {
                byte buffer = new byte[1024];

                while (true)
                {
                    int byteSize = ns.Read(buffer, 0, buffer.Length);
                    ms.Write(buffer, 0, byteSize);

                    if (ms.Length >= readDataLength)
                        break;
                }

                readData = ms.ToArray();
            }

            if (!hash.BytesEquals(new SHA256Managed().ComputeHash(readData)))
                throw new Exception("receive_data_corrupt"); //対応済

            byte data = new byte[dataLength];

            using (MemoryStream ms = new MemoryStream(readData))
                if (rm != null)
                {
                    using (ICryptoTransform icf = rm.CreateDecryptor(rm.Key, rm.IV))
                    using (CryptoStream cs = new CryptoStream(ms, icf, CryptoStreamMode.Read))
                        if (isCompressed)
                            using (DeflateStream ds = new DeflateStream(cs, CompressionMode.Decompress))
                                ds.Read(data, 0, data.Length);
                        else
                            cs.Read(data, 0, data.Length);

                    return data;
                }
                else
                    if (isCompressed)
                    {
                        using (DeflateStream ds = new DeflateStream(ms, CompressionMode.Decompress))
                            ds.Read(data, 0, data.Length);

                        return data;
                    }
                    else
                        return readData;
        }

        private void WriteBytesInner(byte data, bool isCompressed)
        {
            ns.Write(BitConverter.GetBytes(data.Length), 0, 4);

            if (data.Length == 0)
                return;

            byte writeData = null;

            using (MemoryStream ms = new MemoryStream())
            {
                if (rm != null)
                    using (ICryptoTransform icf = rm.CreateEncryptor(rm.Key, rm.IV))
                    using (CryptoStream cs = new CryptoStream(ms, icf, CryptoStreamMode.Write))
                        if (isCompressed)
                            using (DeflateStream ds = new DeflateStream(cs, CompressionMode.Compress))
                            {
                                ds.Write(data, 0, data.Length);
                                ds.Flush();
                            }
                        else
                        {
                            cs.Write(data, 0, data.Length);
                            cs.FlushFinalBlock();
                        }
                else
                    if (isCompressed)
                        using (DeflateStream ds = new DeflateStream(ms, CompressionMode.Compress))
                        {
                            ds.Write(data, 0, data.Length);
                            ds.Flush();
                        }
                    else
                    {
                        ms.Write(data, 0, data.Length);
                        ms.Flush();
                    }

                writeData = ms.ToArray();
            }

            ns.Write(new SHA256Managed().ComputeHash(writeData), 0, 32);
            ns.Write(BitConverter.GetBytes(writeData.Length), 0, 4);
            ns.Write(writeData, 0, writeData.Length);
        }
    }

    public enum RsaKeySize { rsa1024, rsa2048 }

    public class Client
    {
        private readonly string ipAddress;
        private readonly ushort port;
        private readonly string privateRsaParameter;
        private readonly RsaKeySize keySize;
        private readonly Action<CommunicationApparatus, IPEndPoint> protocolProcess;
        private readonly int receiveTimeout;
        private readonly int sendTimeout;
        private readonly int receiveBufferSize;
        private readonly int sendBufferSize;

        private Socket client;

        public Client(string _ipAddress, ushort _port, RsaKeySize _keySize, string _privateRsaParameter, Action<CommunicationApparatus, IPEndPoint> _protocolProcess, int _receiveTimeout, int _sendTimeout, int _receiveBufferSize, int _sendBufferSize)
        {
            ipAddress = _ipAddress;
            port = _port;
            keySize = _keySize;
            privateRsaParameter = _privateRsaParameter;
            protocolProcess = _protocolProcess;
            receiveTimeout = _receiveTimeout;
            sendTimeout = _sendTimeout;
            receiveBufferSize = _receiveBufferSize;
            sendBufferSize = _sendBufferSize;
        }

        public Client(string _ipAddress, ushort _port, Action<CommunicationApparatus, IPEndPoint> _protocolProcess, int _receiveTimeout, int _sendTimeout, int _receiveBufferSize, int _sendBufferSize) : this(_ipAddress, _port, RsaKeySize.rsa2048, null, _protocolProcess, _receiveTimeout, _sendTimeout, _receiveBufferSize, _sendBufferSize) { }

        public Client(string _ipAddress, ushort _port, RsaKeySize _keySize, string _privateRsaParameter, Action<CommunicationApparatus, IPEndPoint> _protocolProcess) : this(_ipAddress, _port, _keySize, _privateRsaParameter, _protocolProcess, 30000, 30000, 8192, 8192) { }

        public Client(string _ipAddress, ushort _port, Action<CommunicationApparatus, IPEndPoint> _protocolProcess) : this(_ipAddress, _port, RsaKeySize.rsa2048, null, _protocolProcess, 30000, 30000, 8192, 8192) { }

        public event EventHandler Connected = delegate { };
        public event EventHandler<Exception> Errored = delegate { };

        public void StartClient()
        {
            if (client != null)
                throw new InvalidOperationException("client_already_started"); //対応済

            this.StartTask*1
                    {
                        RijndaelManaged rm = null;

                        if (privateRsaParameter != null)
                        {
                            RSACryptoServiceProvider rsacsp = new RSACryptoServiceProvider();
                            rsacsp.FromXmlString(privateRsaParameter);

                            if *2
                                throw new Exception("client_rsa_key_size");

                            RSAParameters rsaParameters = rsacsp.ExportParameters(true);
                            byte modulus = rsaParameters.Modulus;
                            byte
exponent = rsaParameters.Exponent;

                            ns.Write(modulus, 0, modulus.Length);
                            ns.Write(exponent, 0, exponent.Length);

                            RSAPKCS1KeyExchangeDeformatter rsapkcs1ked = new RSAPKCS1KeyExchangeDeformatter(rsacsp);

                            byte encryptedKey = keySize == RsaKeySize.rsa1024 ? new byte[128] : new byte[256];
                            byte
encryptedIv = keySize == RsaKeySize.rsa1024 ? new byte[128] : new byte[256];

                            ns.Read(encryptedKey, 0, encryptedKey.Length);
                            ns.Read(encryptedIv, 0, encryptedIv.Length);

                            rm = new RijndaelManaged();
                            rm.Padding = PaddingMode.Zeros;
                            rm.Key = rsapkcs1ked.DecryptKeyExchange(encryptedKey);
                            rm.IV = rsapkcs1ked.DecryptKeyExchange(encryptedIv);
                        }

                        Connected(this, EventArgs.Empty);

                        protocolProcess(new CommunicationApparatus(ns, rm), (IPEndPoint)client.RemoteEndPoint);
                    }
                }
                catch (Exception ex)
                {
                    this.RaiseError("client_socket".GetLogMessage(), 5, ex);

                    EndClient();

                    Errored(this, ex);
                }
            }, "client", string.Empty);
        }

        public void EndClient()
        {
            if (client == null)
                throw new InvalidOperationException("client_not_started"); //対応済

            try
            {
                if (client.Connected)
                    client.Shutdown(SocketShutdown.Both);
                client.Close();
            }
            catch (Exception ex)
            {
                this.RaiseError("client_socket".GetLogMessage(), 5, ex);
            }
        }
    }

    public class Listener
    {
        private readonly ushort port;
        private readonly bool isEncrypted;
        private readonly RsaKeySize keySize;
        private readonly Action<CommunicationApparatus, IPEndPoint> protocolProcess;
        private readonly int receiveTimeout;
        private readonly int sendTimeout;
        private readonly int receiveBufferSize;
        private readonly int sendBufferSize;
        private readonly int backlog;

        private Socket listener = null;
        private readonly object lobject = new object();
        private List<Socket> clients = new List<Socket>();

        private Listener(ushort _port, bool _isEncrypted, RsaKeySize _keySize, Action<CommunicationApparatus, IPEndPoint> _protocolProcess, int _receiveTimeout, int _sendTimeout, int _receiveBufferSize, int _sendBufferSize, int _backlog)
        {
            port = _port;
            isEncrypted = _isEncrypted;
            keySize = _keySize;
            protocolProcess = _protocolProcess;
            receiveTimeout = _receiveTimeout;
            sendTimeout = _sendTimeout;
            receiveBufferSize = _receiveBufferSize;
            sendBufferSize = _sendBufferSize;
            backlog = _backlog;
        }

        public Listener(ushort _port, RsaKeySize _keySize, Action<CommunicationApparatus, IPEndPoint> _protocolProcess, int _receiveTimeout, int _sendTimeout, int _receiveBufferSize, int _sendBufferSize, int _backlog) : this(_port, true, _keySize, _protocolProcess, _receiveTimeout, _sendTimeout, _receiveBufferSize, _sendBufferSize, _backlog) { }

        public Listener(ushort _port, Action<CommunicationApparatus, IPEndPoint> _protocolProcess, int _receiveTimeout, int _sendTimeout, int _receiveBufferSize, int _sendBufferSize, int _backlog) : this(_port, false, RsaKeySize.rsa2048, _protocolProcess, _receiveTimeout, _sendTimeout, _receiveBufferSize, _sendBufferSize, _backlog) { }

        public Listener(ushort _port, RsaKeySize _keySize, Action<CommunicationApparatus, IPEndPoint> _protocolProcess) : this(_port, true, _keySize, _protocolProcess, 30000, 30000, 8192, 8192, 100) { }

        public Listener(ushort _port, Action<CommunicationApparatus, IPEndPoint> _protocolProcess) : this(_port, false, RsaKeySize.rsa2048, _protocolProcess, 30000, 30000, 8192, 8192, 100) { }

        public event EventHandler Connected = delegate { };
        public event EventHandler<Exception> Errored = delegate { };
        public event EventHandler ClientConnected = delegate { };
        public event EventHandler<Exception> ClientErrored = delegate { };

        public void StartListener()
        {
            if (listener != null)
                throw new InvalidOperationException("listener_already_started"); //対応済

            this.StartTask*3;
                    listener.Listen(backlog);

                    while (true)
                    {
                        Socket client = listener.Accept();
                        lock (lobject)
                            clients.Add(client);

                        StartClient(client);
                    }
                }
                catch (Exception ex)
                {
                    this.RaiseError("listner_socket".GetLogMessage(), 5, ex);

                    EndListener();

                    Errored(this, ex);
                }
            }, "listener", string.Empty);
        }

        public void EndListener()
        {
            if (listener == null)
                throw new InvalidOperationException("listener_not_started"); //対応済

            try
            {
                listener.Close();

                lock (lobject)
                    foreach (var client in clients)
                    {
                        if (client.Connected)
                            client.Shutdown(SocketShutdown.Both);
                        client.Close();
                    }
            }
            catch (Exception ex)
            {
                this.RaiseError("listner_socket".GetLogMessage(), 5, ex);
            }
        }

        private void StartClient(Socket client)
        {
            this.StartTask*4
                    {
                        RijndaelManaged rm = null;

                        if (isEncrypted)
                        {
                            byte modulus = keySize == RsaKeySize.rsa1024 ? new byte[128] : new byte[256];
                            byte
exponent = new byte[3];

                            ns.Read(modulus, 0, modulus.Length);
                            ns.Read(exponent, 0, exponent.Length);

                            RSACryptoServiceProvider rsacsp = new RSACryptoServiceProvider();
                            RSAParameters rsaParameters = new RSAParameters();
                            rsaParameters.Modulus = modulus;
                            rsaParameters.Exponent = exponent;
                            rsacsp.ImportParameters(rsaParameters);

                            RSAPKCS1KeyExchangeFormatter rsapkcs1kef = new RSAPKCS1KeyExchangeFormatter(rsacsp);

                            rm = new RijndaelManaged();
                            rm.Padding = PaddingMode.Zeros;

                            byte encryptedKey = rsapkcs1kef.CreateKeyExchange(rm.Key);
                            byte
encryptedIv = rsapkcs1kef.CreateKeyExchange(rm.IV);

                            ns.Write(encryptedKey, 0, encryptedKey.GetLength(0));
                            ns.Write(encryptedIv, 0, encryptedIv.GetLength(0));
                        }

                        ClientConnected(this, EventArgs.Empty);

                        protocolProcess(new CommunicationApparatus(ns, rm), (IPEndPoint)client.RemoteEndPoint);
                    }
                }
                catch (Exception ex)
                {
                    this.RaiseError("listner_socket".GetLogMessage(), 5, ex);

                    EndClient(client);

                    ClientErrored(this, ex);
                }
            }, "listener_client", string.Empty);
        }

        private void EndClient(Socket client)
        {
            try
            {
                lock (lobject)
                    clients.Remove(client);

                if (client.Connected)
                    client.Shutdown(SocketShutdown.Both);
                client.Close();
            }
            catch (Exception ex)
            {
                this.RaiseError("listner_socket".GetLogMessage(), 5, ex);
            }
        }
    }

*1:) =>
            {
                try
                {
                    client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                    client.Connect(IPAddress.Parse(ipAddress), port);
                    client.ReceiveTimeout = receiveTimeout;
                    client.SendTimeout = sendTimeout;
                    client.ReceiveBufferSize = receiveBufferSize;
                    client.SendBufferSize = sendBufferSize;

                    using (NetworkStream ns = new NetworkStream(client

*2:keySize == RsaKeySize.rsa1024 && rsacsp.KeySize != 1024) || (keySize == RsaKeySize.rsa2048 && rsacsp.KeySize != 2048

*3:) =>
            {
                try
                {
                    listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                    listener.Bind(new IPEndPoint(IPAddress.Any, port

*4:) =>
            {
                try
                {
                    client.ReceiveTimeout = receiveTimeout;
                    client.SendTimeout = sendTimeout;
                    client.ReceiveBufferSize = receiveBufferSize;
                    client.SendBufferSize = sendBufferSize;

                    using (NetworkStream ns = new NetworkStream(client