From 9f6b6b6634c2160db724f52ab65bc6cddf00f12a Mon Sep 17 00:00:00 2001 From: David Gaag Date: Wed, 14 Aug 2024 13:43:58 -0600 Subject: [PATCH] fix: alter read to prevent out of memory error --- .../main/java/com/tflite/TfliteModule.java | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/android/src/main/java/com/tflite/TfliteModule.java b/android/src/main/java/com/tflite/TfliteModule.java index 7d36c029..ea1d0fe4 100644 --- a/android/src/main/java/com/tflite/TfliteModule.java +++ b/android/src/main/java/com/tflite/TfliteModule.java @@ -82,14 +82,14 @@ public static byte[] fetchByteDataFromUrl(String url) throws Exception { throw new IOException("File does not exist or is not readable: " + path); } - try (FileInputStream fis = new FileInputStream(file); - ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - byte[] buffer = new byte[8192]; // Larger buffer for efficiency - int bytesRead; - while ((bytesRead = fis.read(buffer)) != -1) { - bos.write(buffer, 0, bytesRead); - } - return bos.toByteArray(); + // Check if the file has a .tflite extension + if (!file.getName().toLowerCase().endsWith(".tflite")) { + throw new SecurityException("Only .tflite files are allowed"); + } + + // Read the file + try (FileInputStream stream = new FileInputStream(file)) { + return getLocalFileBytes(stream, file); } catch (IOException e) { Log.e(NAME, "Error reading file: " + path, e); throw new RuntimeException("Failed to read file: " + path, e); @@ -155,4 +155,26 @@ public boolean install() { } private static native boolean nativeInstall(long jsiPtr, CallInvokerHolderImpl jsCallInvoker); + + private static byte[] getLocalFileBytes(InputStream stream, File file) throws IOException { + long fileSize = file.length(); + + if (fileSize > Integer.MAX_VALUE) { + throw new IOException("File is too large to read into memory"); + } + + byte[] data = new byte[(int) fileSize]; + + int bytesRead = 0; + int chunk; + while (bytesRead < fileSize && (chunk = stream.read(data, bytesRead, (int)fileSize - bytesRead)) != -1) { + bytesRead += chunk; + } + + if (bytesRead != fileSize) { + throw new IOException("Could not completely read file " + file.getName()); + } + + return data; + } }