用Java NIO模拟HTTPS

发布于:2025-04-13 ⋅ 阅读:(23) ⋅ 点赞:(0)

HTTPS流程 

名词解释:
    R1:随机数1 R2:随机数2 R3:随机数3 publicKey:公钥 privateKey:私钥

要提供HTTPS服务,服务端需要安装数字证书(证书中包含公钥)。TCP连接建立之后,服务端在TLS握手时把证书发给客户端,由客户端验证证书。
step1 
客户端 client hello + R1  
服务端 server hello + R2 + publicKey
step2
客户端 生成R3(预主密钥),用publicKey加密后发送给服务端
服务端 privateKey解密得到R3
step3
客户端与服务端使用相同的算法(本例中为HmacSHA256),由三个随机数生成相同的会话密钥(即主密钥)
客户端 R1 + R2 + R3 -> 生成会话密钥
服务端 R1 + R2 + R3 -> 生成会话密钥
step4
正式通信:双方使用会话密钥(对称密钥)加密数据

HttpsServer


import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * Single-threaded NIO server that imitates the shape of a TLS handshake.
 *
 * <p>Wire protocol, per connection, in order:
 * <ol>
 *   <li>client -> server, 14 bytes: tag byte {@code 13}, the 12-byte text
 *       {@code "Client Hello"}, one byte R1;</li>
 *   <li>server -> client: the 12-byte text {@code "Server Hello"}, one byte R2,
 *       one length byte, the encoded RSA public key;</li>
 *   <li>client -> server, 65 bytes: length byte {@code 64}, then R3 encrypted
 *       with the public key;</li>
 *   <li>client -> server, 32 bytes: the session key, which the server compares
 *       with the key it derived from R1, R2 and R3.</li>
 * </ol>
 *
 * <p>All handshake state is kept per connection in the {@link SelectionKey}
 * attachment, and the expected frame is chosen from the connection's phase
 * rather than guessed from the first byte of the payload.
 */
public class HttpsServer {

    private static final int PORT = 8080;

    /** First byte of the client hello frame. */
    private static final byte CLIENT_HELLO_TAG = 13;

    /** Length of the RSA block carrying R3, as announced by the client's length byte. */
    private static final int ENCRYPTED_SECRET_LENGTH = 64;

    /** Charset for the ASCII greetings; fixed so it does not depend on the platform default. */
    private static final java.nio.charset.Charset TEXT = java.nio.charset.StandardCharsets.UTF_8;

    private static final java.security.SecureRandom RANDOM = new java.security.SecureRandom();

    private static Map<String, Object> keyMap;

    /** Handshake phases; each one knows how many bytes its inbound frame has. */
    private enum Phase {
        CLIENT_HELLO(14),
        KEY_EXCHANGE(65),
        ESTABLISHED(32);

        final int frameLength;

        Phase(int frameLength) {
            this.frameLength = frameLength;
        }
    }

    /** Per-connection state, attached to the connection's selection key. */
    private static final class Connection {
        /** Bytes received but not yet consumed as a complete frame. */
        final ByteBuffer inbound = ByteBuffer.allocate(8 * 1024);
        /** R1, R2, R3 in that order. */
        final byte[] randoms = new byte[3];
        Phase phase = Phase.CLIENT_HELLO;
        /** Pending server hello, or null when there is nothing to write. */
        ByteBuffer outbound;
        byte[] sessionKey;
    }

    public static void main(String[] args) throws Exception {
        keyMap = init();
        startServer();
    }

    /**
     * Generates the server's RSA key pair.
     *
     * @return an unmodifiable map holding the public and private key
     * @throws NoSuchAlgorithmException if the key-pair generator is unavailable
     */
    public static Map<String, Object> init() throws NoSuchAlgorithmException {
        Map<String, Object> map = EncryptUtil.initKey();
        return Collections.unmodifiableMap(map);
    }

    /**
     * Binds port 8080 and serves connections until an unrecoverable error occurs.
     * A failure on one connection closes only that connection.
     *
     * @throws Exception if the listener cannot be set up or the selector fails
     */
    public static void startServer() throws Exception {
        if (keyMap == null) {
            // Allows calling startServer() directly without going through main().
            keyMap = init();
        }
        try (Selector selector = Selector.open();
             ServerSocketChannel serverSocketChannel = ServerSocketChannel.open()) {
            serverSocketChannel.configureBlocking(false);
            serverSocketChannel.bind(new InetSocketAddress(PORT));
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
            System.out.println("服务器监听....");

            while (true) {
                if (selector.select() == 0) {
                    continue;
                }
                Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                while (iterator.hasNext()) {
                    SelectionKey key = iterator.next();
                    iterator.remove();
                    if (!key.isValid()) {
                        continue;
                    }
                    if (key.isAcceptable()) {
                        acceptConnection(selector, (ServerSocketChannel) key.channel());
                        continue;
                    }
                    try {
                        if (key.isReadable()) {
                            readFrames(key);
                        }
                        // readFrames may have closed the connection, which cancels the key.
                        if (key.isValid() && key.isWritable()) {
                            flushOutbound(key);
                        }
                    } catch (Exception e) {
                        // Connection boundary: drop this client, keep serving the others.
                        System.err.println("connection error: " + e);
                        closeConnection(key);
                    }
                }
            }
        }
    }

    /** Accepts one pending connection and registers it for reads with fresh state. */
    private static void acceptConnection(Selector selector, ServerSocketChannel listener)
            throws java.io.IOException {
        SocketChannel socketChannel = listener.accept();
        if (socketChannel == null) {
            // Non-blocking accept: nothing was pending after all.
            return;
        }
        System.out.println("连接:" + socketChannel.getRemoteAddress());
        socketChannel.configureBlocking(false);
        socketChannel.register(selector, SelectionKey.OP_READ, new Connection());
    }

    /** Reads whatever is available and processes every complete frame in the buffer. */
    private static void readFrames(SelectionKey key) throws Exception {
        SocketChannel channel = (SocketChannel) key.channel();
        Connection connection = (Connection) key.attachment();
        ByteBuffer inbound = connection.inbound;

        int read = channel.read(inbound);
        if (read < 0) {
            // Peer closed the connection.
            closeConnection(key);
            return;
        }
        if (read == 0) {
            return;
        }
        System.out.println("count:" + inbound.position());

        inbound.flip();
        try {
            // The phase, and with it the expected frame length, may change after each frame.
            while (inbound.remaining() >= connection.phase.frameLength) {
                byte[] frame = new byte[connection.phase.frameLength];
                inbound.get(frame);
                handleFrame(key, connection, frame);
                System.out.println("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++");
            }
        } finally {
            // Keep any incomplete frame for the next read event.
            inbound.compact();
        }
    }

    /** Dispatches one complete frame according to the connection's phase. */
    private static void handleFrame(SelectionKey key, Connection connection, byte[] frame)
            throws Exception {
        switch (connection.phase) {
            case CLIENT_HELLO:
                onClientHello(key, connection, frame);
                break;
            case KEY_EXCHANGE:
                onKeyExchange(connection, frame);
                break;
            default:
                onSessionKey(connection, frame);
                break;
        }
    }

    /** Handles the client hello: records R1 and queues the server hello for writing. */
    private static void onClientHello(SelectionKey key, Connection connection, byte[] frame)
            throws Exception {
        if (frame[0] != CLIENT_HELLO_TAG) {
            throw new IllegalStateException("unexpected client hello tag: " + frame[0]);
        }
        System.out.println(new String(frame, 1, 12, TEXT));
        byte r1 = frame[13]; // R1, sent by the client
        System.out.println("随机数1:" + r1);
        connection.randoms[0] = r1;

        connection.outbound = buildServerHello(connection);
        connection.phase = Phase.KEY_EXCHANGE;
        // Ask for a write event only now that there is something to send.
        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
    }

    /** Builds the server hello (greeting, R2, key length, public key), ready to be written. */
    private static ByteBuffer buildServerHello(Connection connection) throws Exception {
        byte[] greeting = "Server Hello".getBytes(TEXT);

        byte r2 = (byte) RANDOM.nextInt(Byte.MAX_VALUE); // R2, sent to the client
        connection.randoms[1] = r2;

        byte[] publicKey = EncryptUtil.getPublicKey(keyMap);
        if (publicKey.length > 0xFF) {
            throw new IllegalStateException(
                    "public key does not fit a one-byte length prefix: " + publicKey.length);
        }
        System.out.println("publicKey.length:" + publicKey.length);

        ByteBuffer message = ByteBuffer.allocate(greeting.length + 2 + publicKey.length);
        message.put(greeting).put(r2).put((byte) publicKey.length).put(publicKey);
        // Switch the buffer to draining mode before it is handed to channel.write().
        message.flip();

        System.out.println(Arrays.toString(message.array()));
        System.out.println("随机数2:" + r2);
        return message;
    }

    /** Handles the key exchange: decrypts R3 and derives the session key. */
    private static void onKeyExchange(Connection connection, byte[] frame) throws Exception {
        int declaredLength = frame[0];
        if (declaredLength != ENCRYPTED_SECRET_LENGTH) {
            throw new IllegalStateException("unexpected encrypted block length: " + declaredLength);
        }
        byte[] encrypted = Arrays.copyOfRange(frame, 1, frame.length);
        byte[] secret = EncryptUtil.decryptByPrivateKey(encrypted, EncryptUtil.getPrivateKey(keyMap));
        if (secret.length == 0) {
            throw new IllegalStateException("decrypted pre-master secret is empty");
        }
        System.out.println("随机数3=" + secret[0]);
        connection.randoms[2] = secret[0];

        System.out.println("randomList:" + Arrays.toString(connection.randoms));
        // Derive the session key from R1, R2 and R3.
        byte[] sessionKey = EncryptUtil.hmacSHA256(
                EncryptUtil.HmacSHA256_key.getBytes(TEXT), connection.randoms);
        connection.sessionKey = sessionKey;
        System.out.println("会话密钥:" + Arrays.toString(sessionKey));
        connection.phase = Phase.ESTABLISHED;
    }

    /** Compares a session key sent by the client with the one derived for this connection. */
    private static void onSessionKey(Connection connection, byte[] frame) {
        System.out.println("array=" + Arrays.toString(frame));
        // MessageDigest.isEqual is used for a time-constant comparison of the secret.
        if (java.security.MessageDigest.isEqual(frame, connection.sessionKey)) {
            System.out.println("会话密钥验证成功");
        } else {
            System.out.println("会话密钥验证失败");
        }
    }

    /** Writes the pending server hello; drops write interest once it is fully sent. */
    private static void flushOutbound(SelectionKey key) throws java.io.IOException {
        System.out.println("触发写事件....");
        SocketChannel channel = (SocketChannel) key.channel();
        Connection connection = (Connection) key.attachment();

        if (connection.outbound != null) {
            channel.write(connection.outbound);
            if (connection.outbound.hasRemaining()) {
                // Partial write: stay interested and continue on the next write event.
                return;
            }
            connection.outbound = null;
        }
        key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
    }

    /** Cancels the key and closes its channel; close failures are ignored. */
    private static void closeConnection(SelectionKey key) {
        key.cancel();
        try {
            key.channel().close();
        } catch (java.io.IOException ignored) {
            // Best effort: the connection is being discarded anyway.
        }
        System.out.println("有连接关闭...");
    }
}

HttpsClient


import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.*;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

/**
 * Blocking client for the handshake demo.
 *
 * <p>{@link #main} sends, in order: a 14-byte client hello (tag {@code 13},
 * {@code "Client Hello"}, R1); after reading the server hello, a length byte
 * followed by R3 encrypted with the server's public key; and finally the
 * 32-byte session key derived from R1, R2 and R3.
 */
public class HttpsClient {

    private static final String HOST = "localhost";
    private static final int PORT = 8080;

    /** First byte of the client hello frame. */
    private static final byte CLIENT_HELLO_TAG = 13;

    /** Length in bytes of the server's greeting text. */
    private static final int SERVER_HELLO_LENGTH = 12;

    /** Charset for the ASCII greetings; fixed so it does not depend on the platform default. */
    private static final java.nio.charset.Charset TEXT = java.nio.charset.StandardCharsets.UTF_8;

    private static final java.security.SecureRandom RANDOM = new java.security.SecureRandom();

    /** R1, R2, R3 in that order. */
    private static final byte[] randomList = new byte[3];

    public static void main(String[] args) throws Exception {
        try (Socket socket = new Socket(HOST, PORT)) {
            // Linger 0: close() resets the connection instead of doing the normal FIN handshake.
            socket.setSoLinger(true, 0);
            java.io.OutputStream out = socket.getOutputStream();
            InputStream in = socket.getInputStream();

            // Step 1: client hello, sent as a single frame.
            int r1 = RANDOM.nextInt(Byte.MAX_VALUE);
            System.out.println("Client 随机数1:" + r1);
            randomList[0] = (byte) r1;

            byte[] greeting = "Client Hello".getBytes(TEXT);
            ByteBuffer hello = ByteBuffer.allocate(greeting.length + 2);
            hello.put(CLIENT_HELLO_TAG).put(greeting).put((byte) r1);
            out.write(hello.array());

            // Server hello: greeting, R2, one-byte key length, public key.
            String cmd = new String(readFully(in, SERVER_HELLO_LENGTH), TEXT);
            System.out.println("Server " + cmd);

            int r2 = readByte(in);
            System.out.println("Server 随机数2:" + r2);
            randomList[1] = (byte) r2;

            int len = readByte(in);
            System.out.println("publicKey len: " + len);
            byte[] publicKey = readFully(in, len);

            System.out.println("输入任何字符开始第二次通信...");
            System.in.read();

            // Step 2: generate R3 and send it encrypted with the server's public key.
            int r3 = RANDOM.nextInt(Byte.MAX_VALUE);
            byte[] bytes = {(byte) r3};
            randomList[2] = bytes[0];
            System.out.println("随机数3=" + Arrays.toString(bytes));
            byte[] data = EncryptUtil.encryptByPublicKey(bytes, publicKey);
            if (data.length > 0xFF) {
                throw new IllegalStateException(
                        "encrypted block does not fit a one-byte length prefix: " + data.length);
            }
            ByteBuffer exchange = ByteBuffer.allocate(1 + data.length);
            exchange.put((byte) data.length).put(data);
            out.write(exchange.array());

            // Step 3: derive the session key from R1, R2 and R3.
            System.out.println("randomList:" + Arrays.toString(randomList));
            byte[] sessionKey = EncryptUtil.hmacSHA256(
                    EncryptUtil.HmacSHA256_key.getBytes(TEXT), randomList);
            SetCache.add(sessionKey);
            System.out.println("会话密钥:" + Arrays.toString(sessionKey));
            System.out.println("密钥长度:" + SetCache.get().length);

            // Step 4: send the session key so the server can verify it.
            System.out.println("开始正式通信...");
            System.out.println("发送密钥....");
            out.write(SetCache.get());
            out.flush();

            System.out.println("end...");
        }
    }

    /**
     * Reads exactly {@code length} bytes, looping over partial reads.
     *
     * @throws IOException if the stream ends before {@code length} bytes were read
     */
    private static byte[] readFully(InputStream in, int length) throws IOException {
        byte[] data = new byte[length];
        int offset = 0;
        while (offset < length) {
            int n = in.read(data, offset, length - offset);
            if (n < 0) {
                throw new java.io.EOFException(
                        "stream ended after " + offset + " of " + length + " bytes");
            }
            offset += n;
        }
        return data;
    }

    /**
     * Reads one unsigned byte.
     *
     * @throws IOException if the stream has ended
     */
    private static int readByte(InputStream in) throws IOException {
        int value = in.read();
        if (value < 0) {
            throw new java.io.EOFException("stream ended while a byte was expected");
        }
        return value;
    }

    /**
     * Non-blocking variant of the first exchange: connects, sends a 16-byte hello,
     * prints the first 16 bytes of the reply and returns.
     *
     * <p>The hello built here is {@code "Client Hello"} followed by a 4-byte int.
     * It carries no leading tag byte, so it is not the frame that {@link #main} sends.
     *
     * @throws IOException if connecting, writing or reading fails, or the peer
     *                     closes the connection before 16 bytes were received
     */
    public void test() throws IOException {
        try (Selector selector = Selector.open();
             SocketChannel channel = SocketChannel.open()) {
            channel.configureBlocking(false);
            SelectionKey registration = channel.register(selector, SelectionKey.OP_CONNECT);
            if (channel.connect(new InetSocketAddress(HOST, PORT))) {
                // The connection completed immediately; no OP_CONNECT event is needed.
                System.out.println("客户端连接成功...");
                registration.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
            }

            ByteBuffer outbound = null;
            ByteBuffer reply = ByteBuffer.allocate(16);

            while (true) {
                if (selector.select() == 0) {
                    continue;
                }
                Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                while (iterator.hasNext()) {
                    SelectionKey key = iterator.next();
                    iterator.remove();
                    if (!key.isValid()) {
                        continue;
                    }
                    if (key.isConnectable()) {
                        if (channel.finishConnect()) {
                            System.out.println("客户端连接成功...");
                            key.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
                        }
                        continue;
                    }
                    if (key.isWritable()) {
                        if (outbound == null) {
                            System.out.println("Client send...");
                            outbound = ByteBuffer.allocate(16);
                            outbound.put("Client Hello".getBytes(TEXT));
                            outbound.putInt(RANDOM.nextInt(100));
                            // Without the flip, write() would see an empty buffer and send nothing.
                            outbound.flip();
                        }
                        channel.write(outbound);
                        if (!outbound.hasRemaining()) {
                            key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
                            System.out.println("Client send end...");
                        }
                    }
                    if (key.isReadable()) {
                        System.out.println("Client receive...");
                        if (channel.read(reply) < 0) {
                            throw new java.io.EOFException(
                                    "connection closed after " + reply.position() + " reply bytes");
                        }
                        if (!reply.hasRemaining()) {
                            byte[] array = reply.array();
                            System.out.println(new String(array, 0, 12, TEXT));
                            // Bytes 12..15 of the reply (length 4, not an end index).
                            System.out.println(new String(array, 12, 4, TEXT));
                            return;
                        }
                    }
                }
            }
        }
    }
}

EncryptUtil

import javax.crypto.*;
import javax.crypto.spec.SecretKeySpec;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.HashMap;
import java.util.Map;

/**
 * RSA and HMAC-SHA256 helpers used by the handshake demo.
 */
public final class EncryptUtil {

    // Asymmetric algorithm used for key generation and key factories.
    private static final String KEY_ALGORITHM = "RSA";
    // Fully specified cipher transformation, so mode and padding do not depend on
    // the security provider's defaults for the bare name "RSA".
    private static final String CIPHER_TRANSFORMATION = "RSA/ECB/PKCS1Padding";
    // Map key under which the public key is stored.
    private static final String PUBLIC_KEY = "PUBLIC_KEY";
    // Map key under which the private key is stored.
    private static final String PRIVATE_KEY = "PRIVATE_KEY";
    // RSA key length in bits (original note: the default is 1024 and the value must
    // be a multiple of 64). 512 bits gives 64-byte cipher blocks; it is demo-sized
    // and far too weak for real use.
    private static final int KEY_SIZE = 512;
    // MAC algorithm used to derive the session key.
    private static final String HMAC_ALGORITHM = "HmacSHA256";

    public static final String HmacSHA256_key = "HmacSHA256_key";

    /**
     * Generates a fresh RSA key pair.
     *
     * @return a mutable map holding the public key and the private key
     * @throws NoSuchAlgorithmException if no RSA key-pair generator is available
     */
    public static Map<String, Object> initKey() throws NoSuchAlgorithmException {
        KeyPairGenerator generator = KeyPairGenerator.getInstance(KEY_ALGORITHM);
        generator.initialize(KEY_SIZE);
        KeyPair pair = generator.generateKeyPair();

        Map<String, Object> keyMap = new HashMap<>();
        keyMap.put(PUBLIC_KEY, (RSAPublicKey) pair.getPublic());
        keyMap.put(PRIVATE_KEY, (RSAPrivateKey) pair.getPrivate());
        return keyMap;
    }

    /**
     * Returns the encoded public key stored in {@code keyMap}.
     *
     * @param keyMap a map produced by {@link #initKey()}
     * @return the key's encoded form, as accepted by {@link #encryptByPublicKey}
     * @throws IllegalArgumentException if the map holds no public key
     */
    public static byte[] getPublicKey(Map<String, Object> keyMap) throws Exception {
        return requireKey(keyMap, PUBLIC_KEY).getEncoded();
    }


    /**
     * Returns the encoded private key stored in {@code keyMap}.
     *
     * @param keyMap a map produced by {@link #initKey()}
     * @return the key's encoded form, as accepted by {@link #decryptByPrivateKey}
     * @throws IllegalArgumentException if the map holds no private key
     */
    public static byte[] getPrivateKey(Map<String, Object> keyMap) throws NoSuchAlgorithmException {
        return requireKey(keyMap, PRIVATE_KEY).getEncoded();
    }

    /**
     * Encrypts {@code data} with an X.509-encoded RSA public key.
     *
     * @param data plaintext; must fit in a single RSA block
     * @param key  the encoded public key
     * @return the ciphertext
     */
    public static byte[] encryptByPublicKey(byte[] data, byte[] key) throws NoSuchAlgorithmException, InvalidKeySpecException, NoSuchPaddingException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException {
        // Rebuild the public key from its encoded form.
        KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
        PublicKey publicKey = keyFactory.generatePublic(new X509EncodedKeySpec(key));

        Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
        cipher.init(Cipher.ENCRYPT_MODE, publicKey);
        return cipher.doFinal(data);
    }

    /**
     * Decrypts {@code data} with a PKCS#8-encoded RSA private key.
     *
     * @param data ciphertext produced by {@link #encryptByPublicKey}
     * @param key  the encoded private key
     * @return the plaintext
     * @throws Exception if the key cannot be rebuilt or decryption fails
     */
    public static byte[] decryptByPrivateKey(byte[] data, byte[] key) throws Exception {
        // Rebuild the private key from its encoded form.
        KeyFactory keyFactory = KeyFactory.getInstance(KEY_ALGORITHM);
        PrivateKey privateKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(key));

        Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
        cipher.init(Cipher.DECRYPT_MODE, privateKey);
        return cipher.doFinal(data);
    }

    /**
     * Computes HMAC-SHA256 of {@code content}; the result is always 32 bytes long.
     *
     * @param key     the MAC key; any non-empty byte sequence
     * @param content the data to authenticate
     * @return the 32-byte MAC
     * @throws Exception if the MAC algorithm is unavailable or the key is rejected
     */
    public static byte[] hmacSHA256(byte[] key, byte[] content) throws Exception {
        Mac mac = Mac.getInstance(HMAC_ALGORITHM);
        mac.init(new SecretKeySpec(key, HMAC_ALGORITHM));
        return mac.doFinal(content);
    }

    /** Looks up a key in the map and fails with a descriptive message when it is absent. */
    private static Key requireKey(Map<String, Object> keyMap, String name) {
        Object value = keyMap.get(name);
        if (!(value instanceof Key)) {
            throw new IllegalArgumentException("key map has no " + name + " entry");
        }
        return (Key) value;
    }
}

SetCache

/**
 * Process-wide holder for a single 32-byte key.
 */
public class SetCache {
    // Number of bytes stored; add() copies exactly this many from its argument.
    private static final int KEY_LENGTH = 32;

    private static final byte[] cache = new byte[KEY_LENGTH];

    /**
     * Stores the first 32 bytes of {@code key}, replacing the previous content.
     *
     * @param key the key material; at least 32 bytes long
     * @throws NullPointerException     if {@code key} is null
     * @throws IllegalArgumentException if {@code key} is shorter than 32 bytes
     */
    public static void add(byte[] key) {
        if (key == null) {
            throw new NullPointerException("key");
        }
        if (key.length < KEY_LENGTH) {
            throw new IllegalArgumentException(
                    "key must be at least " + KEY_LENGTH + " bytes, got " + key.length);
        }
        System.arraycopy(key, 0, cache, 0, KEY_LENGTH);
    }

    /**
     * Returns a copy of the stored key, so callers cannot modify the cached bytes.
     *
     * @return a new 32-byte array; all zeros if nothing has been stored yet
     */
    public static byte[] get() {
        return cache.clone();
    }
}