基于Java的高并发多线程分片断点下载

首先直接看测试情况:

单线程下载72MB文件

image-20220621112741769

7线程并发分片下载72MB文件:

image-20220621112811107

下载效率提高了2-3倍。当然,以上测试结果还与设备的CPU核心数、网络带宽等因素密切相关。

一、原理

分片下载的核心是HTTP/1.1中的请求头Range,它允许客户端只请求网络资源中的一部分(例如 Range: bytes=0-1023 表示请求前1024个字节),支持范围请求的服务器会以206 Partial Content状态码返回对应片段。基于此功能,我们可以结合Java多线程来开发一个多线程分片断点下载的辅助类,具体实现流程见下文。

二、源代码

下面看一下全部源代码

import com.sccl.autojob.util.id.SystemClock;
import lombok.AccessLevel;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLConnection;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * @Description 文件下载辅助类,支持分段并发下载,支持断点下载
 * @Author Huang Yongxiang
 * @Date 2022/05/27 15:56
 */
@Slf4j
public class FileDownloadHelper {
    /**
     * 文件元数据
     */
    private FileMetaData fileMetaData;
    /**
     * 请求方式
     */
    private String way;
    /**
     * 连接超时时长:ms
     */
    private int connectTimeout;
    /**
     * 读取数据超时时长:ms
     */
    private int readTimeout;
    /**
     * 分片数目
     */
    private int splitCount;
    /**
     * 完整数据的连接对象
     */
    private HttpURLConnection connection;

    private FileSplitFetchTask[] fetchTasks;

    private FileDownloadHelper() {
    }

    public static Builder builder() {
        return new Builder();
    }

    private long getLength() {
        try {
            if (fileMetaData.length != -1) {
                return fileMetaData.length;
            }
            HttpURLConnection connection = getConnection();
            if (connection.getResponseCode() == HttpURLConnection.HTTP_OK) {
                long length = Long.parseLong(connection.getHeaderField("Content-Length"));
                this.fileMetaData.length = length;
                return length;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return -1;
    }

    private HttpURLConnection getConnection() {
        try {
            if (connection == null) {
                HttpURLConnection connection = (HttpURLConnection) new URL(fileMetaData.url).openConnection();
                connection.setReadTimeout(readTimeout);
                connection.setConnectTimeout(connectTimeout);
                connection.setRequestMethod(way);
                this.connection = connection;
                return connection;
            }
        } catch (Exception e) {
            log.error("获取连接时发生异常:{}", e.getMessage());
        }
        return connection;
    }

    private HttpURLConnection getSplitConnection(long startPos, long endPos) {
        HttpURLConnection connection = null;
        try {
            connection = (HttpURLConnection) new URL(fileMetaData.url).openConnection();
            connection.setReadTimeout(readTimeout);
            connection.setConnectTimeout(connectTimeout);
            connection.setRequestMethod(way);
            String prop = "bytes=" + startPos + "-" + endPos;
            log.info("分片参数:{}", prop);
            connection.setRequestProperty("RANGE", prop);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return connection;
    }

    private InputStream getInputStreamFromConnection(HttpURLConnection connection) {
        try {
            return connection.getInputStream();
        } catch (IOException e) {
            log.error("获取输入流时发生异常:{}", e.getMessage());
        }
        return null;
    }

    public FileMetaData getFileMetaData() {
        return fileMetaData;
    }

    private FileSplitFetchTask[] createFileSplitFetchTask() {
        FileSplitFetchTask[] fetchTasks = new FileSplitFetchTask[splitCount];
        long length = getLength();
        if (length == -1) {
            log.error("文件大小未知,服务创建分片任务失败");
            return fetchTasks;
        } else if (length == -2) {
            log.error("文件:{}不存在", fileMetaData.name);
            return fetchTasks;
        }
        int lastEndPos = 0;
        int startPos = 0;
        int endPos = 0;
        double averageLength = (length * 1.0) / splitCount;
        for (int i = 0; i < splitCount; i++) {
            if (lastEndPos != 0) {
                startPos = lastEndPos + 1;
            }
            endPos = startPos + (int) Math.ceil(averageLength);
            fetchTasks[i] = new FileSplitFetchTask(startPos, endPos, i);
            lastEndPos = endPos;
        }
        return fetchTasks;
    }


    public InputStream getInputStream() {
        return fileMetaData.getInputStream();
    }


    public boolean write() {
        try {
            if (fileMetaData.content != null && fileMetaData.content.size() > 0 && !StringUtils.isEmpty(fileMetaData.path)) {
                File file = new File(fileMetaData.path + File.separator + fileMetaData.name);
                FileOutputStream outputStream = new FileOutputStream(file);
                outputStream.write(fileMetaData.getContentAsByteArray());
                outputStream.flush();
                outputStream.close();
                return true;

            } else {
                log.error("没有指定写入路径,写入失败");
                return false;
            }
        } catch (Exception e) {
            log.error("写入时发生异常:{}", e.getMessage());
            e.printStackTrace();
        }
        return false;
    }

    public void stopDownload() {
        if (this.fetchTasks != null) {
            for (FileSplitFetchTask task : fetchTasks) {
                task.stop();
            }
        }
    }

    public void continueDownload() {
        if (this.fetchTasks != null) {
            for (FileSplitFetchTask task : fetchTasks) {
                task.goOn();
            }
        }
    }

    public void download() {
        int totalCount = 0;
        try {
            //构建分片任务对象
            FileSplitFetchTask[] fetchTasks = createFileSplitFetchTask();
            this.fetchTasks = fetchTasks;
            List<FutureTask<Integer>> futureTasks = new ArrayList<>();
            for (int i = 0; i < fetchTasks.length; i++) {
                FutureTask<Integer> futureTask = new FutureTask<>(fetchTasks[i]);
                futureTasks.add(futureTask);
                Thread thread = new Thread(futureTask);
                thread.setName(String.valueOf(i));
                //启动下载
                thread.start();
            }

            //阻塞等待所有线程下载完
            for (FutureTask<Integer> future : futureTasks) {
                totalCount += future.get();
            }
            //拼接内容
            for (FileSplitFetchTask task : fetchTasks) {
                byte[] cache = task.getContent();
                for (byte b : cache) {
                    fileMetaData.content.add(b);
                }
            }
            //写出
            if (!StringUtils.isEmpty(fileMetaData.path) && write()) {
                log.info("写出成功,路径:{}", fileMetaData.path + File.separator + fileMetaData.name);
            }
        } catch (Exception e) {
            e.printStackTrace();
            log.error("下载过程发生异常:{}", e.getMessage());
            return;
        }
        if (totalCount != fileMetaData.length) {
            log.warn("下载字节数:{}与实际字节数:{}不匹配", totalCount, fileMetaData.length);
        } else {
            //log.info("下载成功");
        }

    }

    public void clear() {
        this.fileMetaData = null;
        connection.disconnect();
        connection = null;
        System.gc();
    }

    /**
     * 下载文件的元数据
     */
    @Setter(AccessLevel.PRIVATE)
    private static class FileMetaData {
        /**
         * 地址
         */
        private String url;
        /**
         * 长度
         */
        private long length = -1;
        /**
         * 写入路径
         */
        private String path;
        /**
         * 文件名,包含后缀
         */
        private String name;
        /**
         * 后缀,不包含.
         */
        private String suffix;
        /**
         * 内容,二进制字列表
         */
        private List<Byte> content;

        private byte[] arrayContent;
        /**
         * 文件的输入流
         */
        private InputStream inputStream;

        public byte[] getContentAsByteArray() {
            if (arrayContent != null) {
                return arrayContent;
            }
            if (content != null) {
                int i = 0;
                byte[] holder = new byte[content.size()];
                for (Byte bt : content) {
                    holder[i++] = bt;
                }
                arrayContent = holder;
                return holder;

            }
            return new byte[]{};
        }

        public InputStream getInputStream() {
            if (inputStream == null) {
                inputStream = new ByteArrayInputStream(getContentAsByteArray());
            }
            return inputStream;
        }
    }

    @Setter
    @Accessors(chain = true)
    public static class Builder {
        /**
         * 地址
         */
        private String url;
        /**
         * 写入路径
         */
        private String path;
        /**
         * 文件名,包含后缀
         */
        private String name;
        /**
         * 请求方式
         */
        private String way = "get";
        /**
         * 连接超时时长:ms
         */
        private int connectTimeout = 5000;
        /**
         * 读取数据超时时长:ms
         */
        private int readTimeout = 5000;
        /**
         * 是否允许分片下载
         */
        private boolean allowSplitDownload = true;
        /**
         * 分片数目
         */
        private int splitCount = 3;

        public Builder setConnectTimeout(int connectTimeout, TimeUnit unit) {
            this.connectTimeout = (int) unit.toMillis(connectTimeout);
            return this;
        }

        public Builder setReadTimeout(int readTimeout, TimeUnit unit) {
            this.readTimeout = (int) unit.toMillis(readTimeout);
            return this;
        }

        public FileDownloadHelper build() {
            if (!check()) {
                throw new IllegalArgumentException("错误参数,请检查");
            }
            FileDownloadHelper fileDownloadHelper = new FileDownloadHelper();
            FileMetaData fileMetaData = new FileMetaData();
            if (StringUtils.isEmpty(name)) {
                int namePos = url.trim().lastIndexOf("/");
                name = url.substring(namePos + 1);
            }
            fileMetaData.setName(name);
            int pos = fileMetaData.name.lastIndexOf(".");
            if (pos != -1) {
                fileMetaData.setSuffix(fileMetaData.name.substring(pos));
            }
            fileMetaData.setPath(path);
            fileMetaData.setUrl(url.trim());
            fileMetaData.content = new ArrayList<>();
            if (way.trim().equalsIgnoreCase("get") || way.trim().equalsIgnoreCase("post")) {
                fileDownloadHelper.way = way.trim().toUpperCase();
            }
            fileDownloadHelper.connectTimeout = connectTimeout;
            fileDownloadHelper.readTimeout = readTimeout;
            fileDownloadHelper.fileMetaData = fileMetaData;
            fileDownloadHelper.splitCount = allowSplitDownload ? splitCount : 1;
            return fileDownloadHelper;
        }

        private boolean check() {
            boolean flag = url.lastIndexOf("/") != -1 || url.lastIndexOf(File.separator) != -1;
            return !StringUtils.isEmpty(url) && flag && splitCount > 0;
        }
    }

    private class FileSplitFetchTask implements Callable<Integer> {
        /**
         * 开始索引
         */
        private final long startPos;
        /**
         * 终止索引
         */
        private final long endPos;
        /**
         * 开始标志
         */
        private boolean isStart = false;
        /**
         * 结束标志
         */
        private boolean isOver = false;
        /**
         * 暂停标志
         */
        private AtomicBoolean isStop = new AtomicBoolean(false);
        /**
         * 线程号
         */
        private final int threadId;
        /**
         * 内容
         */
        private final byte[] content;
        /**
         * 分片请求
         */
        private final HttpURLConnection splitConnection;

        public FileSplitFetchTask(long startPos, long endPos, int threadId) {
            if (endPos < startPos) {
                throw new IllegalArgumentException("终止索引不得小于起始索引");
            }
            this.startPos = startPos;
            this.endPos = endPos;
            this.threadId = threadId;
            this.content = new byte[(int) (endPos - startPos + 1)];
            this.splitConnection = getSplitConnection(startPos, endPos);
        }

        public byte[] getContent() {
            if (isOver) {
                return content;
            } else {
                log.error("线程:{}正在拉取,无法获取", threadId);
                return null;
            }
        }

        public void stop() {
            log.info("线程:{}已暂停下载", threadId);
            this.isStop.set(true);
        }

        public void goOn() {
            log.info("线程:{}已继续下载", threadId);
            this.isStop.set(false);
        }

        public boolean isStart() {
            return isStart;
        }

        public boolean isOver() {
            return isOver;
        }

        public boolean isStop() {
            return isStop.get();
        }

        @Override
        public Integer call() throws Exception {
            long start = SystemClock.now();
            log.info("线程:{}下载开始", threadId);
            InputStream inputStream = getInputStreamFromConnection(splitConnection);
            if (inputStream == null) {
                throw new NullPointerException("线程:" + threadId + "下载失败,输入流为空");
            }
            int cache;
            try {
                isStart = true;
                int pos = 0;
                while (!isStop.get()) {
                    cache = inputStream.read();
                    if (cache != -1) {
                        content[pos++] = (byte) cache;
                    } else {
                        break;
                    }
                }
                log.info("线程:{}已下载完,共计:{}字节,共计用时:{}ms", threadId, pos, SystemClock.now() - start);
                isStop.set(true);
                isOver = true;
                inputStream.close();
                return pos;
            } catch (IOException e) {
                e.printStackTrace();
            }
            return 0;
        }
    }


}

三、源码讲解

模块说明

源码包含三个内部类,分别是FileMetaData、Builder和FileSplitFetchTask。FileMetaData是下载文件的元数据描述,包含地址、长度、写入路径、文件名、后缀以及文件的二进制内容和输入流;Builder是构建者模式中的构建者角色,用于构建FileDownloadHelper对象;FileSplitFetchTask是分片下载任务,实现了Callable接口,供线程执行。

下载逻辑

用户启动下载时,先请求资源地址获取文件长度,然后根据文件长度和分片数目计算每个分片的字节范围,为每个分片创建对应的HttpURLConnection对象,继而创建FileSplitFetchTask对象。一切就绪后,为每个FileSplitFetchTask创建一个线程执行,主线程通过FutureTask.get()阻塞等待所有子线程下载完成,再按分片顺序将下载的字节数据拼接起来。下载时逐字节读取数据,主线程可以通过FileSplitFetchTask对象进行暂停和恢复操作。下载完成后,如果构建时指定了写入路径,文件会直接写入到该路径。

四、使用

对于要下载的网络资源,首先应确认其是否支持范围请求:请求该资源的URL地址,查看返回的HTTP响应头中是否包含Accept-Ranges,如果包含且其值不是none,则说明该资源支持范围请求。示例响应头如下,其中Content-Length是该资源的完整大小(字节数)。

Accept-Ranges: bytes
Content-Length: 146515
/**
 * Usage example: downloads each URL in 10 splits (GET, 5 s connect / 60 s read timeout) and
 * prints the total elapsed time. No path is configured, so the content is kept in memory only.
 */
public static void main(String[] args) {
    long start = SystemClock.now();
    String[] urls = new String[]{"https://avatar-1309914555.cos.ap-chengdu.myqcloud.com/UU-4.27.0.exe"};
    for (String url : urls) {
        FileDownloadHelper downloadHelper = FileDownloadHelper.builder()
                .setUrl(url)
                .setWay("get")
                .setAllowSplitDownload(true)
                .setSplitCount(10)
                .setConnectTimeout(5, TimeUnit.SECONDS)
                .setReadTimeout(60, TimeUnit.SECONDS)
                .build();
        downloadHelper.download();
    }
    System.out.println("下载完成,总计用时:" + (SystemClock.now() - start) + "ms");
}

以上代码是一个使用示例,示例网络资源是网易UU加速器的安装包(UU-4.27.0.exe),存放在腾讯云对象存储(COS)中。使用时通过构建者模式配置相关参数并创建对象,然后直接调用download方法即可开始下载,更多细节可以阅读源码。

image-20220621112811107