Skip to content

Zoo features, agent team #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: v3_develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/workflows/zoo_helper.workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
build-zoo-helper-linux-x86_64:
runs-on: ubuntu-latest
container:
image: almalinux:8
image: almalinux:8.10-20240528
steps:
- name: Cache .hunter folder
uses: actions/cache@v3
Expand Down Expand Up @@ -39,7 +39,7 @@ jobs:
dnf install -y pkgconf-pkg-config bison autoconf libtool libXi-devel libXtst-devel cmake zip perl-core python39
dnf install -y libXrandr-devel libX11-devel libXft-devel libXext-devel flex systemd-devel
dnf install -y gcc-c++ automake libtool-ltdl-devel wget
wget https://www.nasm.us/pub/nasm/releasebuilds/2.16.01/nasm-2.16.01.tar.gz && tar -xzf nasm-2.16.01.tar.gz && cd nasm-2.16.01 && ./configure && make && make install && cd .. # install nasm - build from source
wget https://github.com/netwide-assembler/nasm/archive/refs/tags/nasm-2.15.04.tar.gz && tar -xzf nasm-2.15.04.tar.gz && cd nasm-nasm-2.15.04 && ./autogen.sh && ./configure && make && make install && cd .. # install nasm - build from source
pip3 install jinja2

- name: Configure project
Expand All @@ -65,7 +65,7 @@ jobs:
build-zoo-helper-linux-arm64:
runs-on: [self-hosted, linux, ARM64]
container:
image: arm64v8/almalinux:8
image: arm64v8/almalinux:8.10-20240528
# Mount local hunter cache directory, instead of transfering to Github and back
volumes:
- /.hunter:/github/home/.hunter
Expand Down Expand Up @@ -95,7 +95,6 @@ jobs:
dnf install -y pkgconf-pkg-config bison autoconf libtool libXi-devel libXtst-devel cmake git zip perl-core python39
dnf install -y libXrandr-devel libX11-devel libXft-devel libXext-devel flex systemd-devel
dnf install -y gcc-c++ automake libtool-ltdl-devel wget
wget https://www.nasm.us/pub/nasm/releasebuilds/2.16.01/nasm-2.16.01.tar.gz && tar -xzf nasm-2.16.01.tar.gz && cd nasm-2.16.01 && ./configure && make && make install && cd .. # install nasm - build from source
pip3 install jinja2
pip3 install ninja # ninja is needed for cmake on arm64

Expand Down
2 changes: 2 additions & 0 deletions bindings/python/src/modelzoo/ZooBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ void ZooBindings::bind(pybind11::module& m, void* pCallstack) {
py::arg("useCached") = true,
py::arg("cacheDirectory") = "",
py::arg("apiKey") = "",
py::arg("progressFormat") = "none",
DOC(dai, getModelFromZoo));

m.def("downloadModelsFromZoo",
downloadModelsFromZoo,
py::arg("path"),
py::arg("cacheDirectory") = "",
py::arg("apiKey") = "",
py::arg("progressFormat") = "none",
DOC(dai, downloadModelsFromZoo));

// Bind NNModelDescription
Expand Down
7 changes: 5 additions & 2 deletions include/depthai/modelzoo/Zoo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ struct NNModelDescription {
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default value is used (see getDefaultCachePath).
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
* that if set. Otherwise, no API key is used.
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
* @return std::string: Path to the model in cache
*/
std::string getModelFromZoo(const NNModelDescription& modelDescription,
bool useCached = true,
const std::string& cacheDirectory = "",
const std::string& apiKey = "");
const std::string& apiKey = "",
const std::string& progressFormat = "none");
Comment on lines +89 to +96
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason to not go with an enum here instead of a string?


/**
* @brief Helper function allowing one to download all models specified in yaml files in the given path and store them in the cache directory
Expand All @@ -101,9 +103,10 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default is used (see getDefaultCachePath).
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
* that if set. Otherwise, no API key is used.
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
* @return bool: True if all models were downloaded successfully, false otherwise
*/
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "");
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "", const std::string& progressFormat = "none");

std::ostream& operator<<(std::ostream& os, const NNModelDescription& modelDescription);

Expand Down
172 changes: 163 additions & 9 deletions src/modelzoo/Zoo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
#include <cctype>
#include <filesystem>
#include <iostream>
#include <memory>
#include <mutex>
#include <nlohmann/json.hpp>
#include <nlohmann/json_fwd.hpp>

#include "utility/Environment.hpp"
#include "utility/Logging.hpp"
Expand All @@ -12,6 +15,7 @@

#ifdef DEPTHAI_ENABLE_CURL
#include <cpr/api.h>
#include <cpr/cprtypes.h>
#include <cpr/parameters.h>
#include <cpr/status_codes.h>
#endif
Expand Down Expand Up @@ -96,8 +100,11 @@ class ZooManager {

/**
* @brief Download model from model zoo
*
* @param responseJson: JSON with download links
* @param cprCallback: Progress callback
*/
void downloadModel(const nlohmann::json& responseJson);
void downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback);

/**
* @brief Return path to model in cache
Expand Down Expand Up @@ -328,7 +335,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
return responseJson;
}

void ZooManager::downloadModel(const nlohmann::json& responseJson) {
void ZooManager::downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback) {
// Extract download links from response
auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>();
auto downloadHash = responseJson["hash"].get<std::string>();
Expand All @@ -346,7 +353,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {

// Download all files and store them in cache folder
for(const auto& downloadLink : downloadLinks) {
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink));
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink), *cprCallback);
if(checkIsErrorModelDownload(downloadResponse)) {
removeModelCacheFolder();
throw std::runtime_error(generateErrorMessageModelDownload(downloadResponse));
Expand Down Expand Up @@ -395,7 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
return std::filesystem::absolute(folderFiles[0]).string();
}

std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
class CprCallback {
public:
virtual ~CprCallback() = default;
CprCallback(const std::string& modelName) : modelName(modelName) {}

virtual void cprCallback(
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;

virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback() {
return std::make_unique<cpr::ProgressCallback>(
[this](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
this->cprCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
return true;
});
}

protected:
std::string modelName;
};

class JsonCprCallback : public CprCallback {
constexpr static long long PRINT_INTERVAL_MS = 100;

public:
JsonCprCallback(const std::string& modelName) : CprCallback(modelName) {
startTime = std::chrono::steady_clock::time_point::min();
}

void print(long downloadTotal, long downloadNow, const std::string& modelName) {
nlohmann::json json = {
{"download_total", downloadTotal},
{"download_now", downloadNow},
{"model_name", modelName},
};
std::cout << json.dump() << std::endl;
}

void cprCallback(
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
(void)uploadTotal;
(void)uploadNow;
(void)userdata;

bool firstCall = startTime == std::chrono::steady_clock::time_point::min();
if(firstCall || downloadTotal == 0) {
startTime = std::chrono::steady_clock::now();
}

bool shouldPrint = std::chrono::steady_clock::now() - startTime > std::chrono::milliseconds(PRINT_INTERVAL_MS) || this->downloadTotal != downloadTotal;

if(shouldPrint) {
print(downloadTotal, downloadNow, modelName);
startTime = std::chrono::steady_clock::now();
}

this->downloadTotal = downloadTotal;
this->downloadNow = downloadNow;
}

~JsonCprCallback() override {
if(downloadTotal != 0) {
print(downloadTotal, downloadNow, modelName);
}
}

private:
long downloadTotal = 0;
long downloadNow = 0;
std::chrono::steady_clock::time_point startTime;
};

class PrettyCprCallback : public CprCallback {
public:
PrettyCprCallback(const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false) {}

void cprCallback(
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
(void)uploadTotal;
(void)uploadNow;
(void)userdata;

if(finalProgressPrinted) return;

if(downloadTotal > 0) {
float progress = static_cast<float>(downloadNow) / downloadTotal;
int barWidth = 50;
int pos = static_cast<int>(barWidth * progress);

std::cout << "\rDownloading " << modelName << " [";
for(int i = 0; i < barWidth; ++i) {
if(i < pos)
std::cout << "=";
else if(i == pos)
std::cout << ">";
else
std::cout << " ";
}
std::cout << "] " << std::fixed << std::setprecision(3) << progress * 100.0f << "% " << downloadNow / 1024.0f / 1024.0f << "/"
<< downloadTotal / 1024.0f / 1024.0f << " MB";

if(downloadNow == downloadTotal) {
std::cout << std::endl;
finalProgressPrinted = true;
} else {
std::cout << "\r";
std::cout.flush();
}
}
}

private:
bool finalProgressPrinted;
};

class NoneCprCallback : public CprCallback {
public:
NoneCprCallback(const std::string& modelName) : CprCallback(modelName) {}

void cprCallback(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) override {
// Do nothing
}
};

std::unique_ptr<CprCallback> getCprCallback(const std::string& format, const std::string& name) {
if(format == "json") {
return std::make_unique<JsonCprCallback>(name);
} else if(format == "pretty") {
return std::make_unique<PrettyCprCallback>(name);
} else if(format == "none") {
return std::make_unique<NoneCprCallback>(name);
}
throw std::runtime_error("Invalid format: " + format);
}

std::string getModelFromZoo(const NNModelDescription& modelDescription,
bool useCached,
const std::string& cacheDirectory,
const std::string& apiKey,
const std::string& progressFormat) {
// Check if model description is valid
if(!modelDescription.check()) throw std::runtime_error("Invalid model description:\n" + modelDescription.toString());

Expand Down Expand Up @@ -466,9 +611,12 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
// Create cache folder
zooManager.createCacheFolder();

// Create download progress callback
std::unique_ptr<CprCallback> cprCallback = getCprCallback(progressFormat, modelDescription.globalMetadataEntryName.size() > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model);

// Download model
logger::info("Downloading model from model zoo");
zooManager.downloadModel(responseJson);
zooManager.downloadModel(responseJson, cprCallback->getCprProgressCallback());

// Store model as yaml in the cache folder
std::string yamlPath = combinePaths(zooManager.getModelCacheFolderPath(cacheDirectory), "model.yaml");
Expand All @@ -479,7 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
return modelPath;
}

bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
logger::info("Downloading models from zoo");
// Make sure 'path' exists
if(!std::filesystem::exists(path)) throw std::runtime_error("Path does not exist: " + path);
Expand Down Expand Up @@ -507,7 +655,7 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
try {
logger::info("Downloading model [{} / {}]: {}", i + 1, models.size(), modelName);
auto modelDescription = NNModelDescription::fromYamlFile(modelName, path);
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey);
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, progressFormat);
logger::info("Downloaded model [{} / {}]: {}", i + 1, models.size(), modelName);
numSuccess++;
} catch(const std::exception& e) {
Expand Down Expand Up @@ -546,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {

#else

std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
std::string getModelFromZoo(const NNModelDescription& modelDescription,
bool useCached,
const std::string& cacheDirectory,
const std::string& apiKey,
const std::string& progressFormat) {
(void)modelDescription;
(void)useCached;
(void)cacheDirectory;
(void)apiKey;
(void)progressFormat;
throw std::runtime_error("getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
}

bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
(void)path;
(void)cacheDirectory;
(void)apiKey;
(void)progressFormat;
throw std::runtime_error("downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
}

Expand Down
22 changes: 15 additions & 7 deletions src/modelzoo/zoo_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ int main(int argc, char* argv[]) {
const std::string DEFAULT_DOWNLOAD_ENDPOINT = dai::modelzoo::getDownloadEndpoint();
program.add_argument("--download_endpoint").default_value(DEFAULT_DOWNLOAD_ENDPOINT).help("Endpoint to use for downloading models");

const std::string FORMAT_DEFAULT = "pretty";
program.add_argument("--format").default_value(FORMAT_DEFAULT).help("Format to use for output (possible values: pretty, json)");

program.add_argument("--verbose").default_value(false).implicit_value(true).help("Verbose output");

// Parse arguments
Expand All @@ -50,9 +53,10 @@ int main(int argc, char* argv[]) {
auto apiKey = program.get<std::string>("--api_key");
auto healthEndpoint = program.get<std::string>("--health_endpoint");
auto downloadEndpoint = program.get<std::string>("--download_endpoint");
auto format = program.get<std::string>("--format");

bool verbose = program.get<bool>("--verbose");
if(!dai::utility::isEnvSet("DEPTHAI_LEVEL") && verbose) {
if(!dai::utility::isEnvSet("DEPTHAI_LEVEL") && verbose && format == "pretty") {
dai::Logging::getInstance().logger.set_level(spdlog::level::info);
}

Expand All @@ -61,19 +65,23 @@ int main(int argc, char* argv[]) {
dai::modelzoo::setDownloadEndpoint(downloadEndpoint);

// Print arguments
std::cout << "Downloading models defined in yaml files in folder: " << yamlFolder << std::endl;
std::cout << "Downloading models into cache folder: " << cacheFolder << std::endl;
if(!apiKey.empty()) {
std::cout << "Using API key: " << apiKey << std::endl;
if(format == "pretty") {
std::cout << "Downloading models defined in yaml files in folder: " << yamlFolder << std::endl;
std::cout << "Downloading models into cache folder: " << cacheFolder << std::endl;
if(!apiKey.empty()) {
std::cout << "Using API key: " << apiKey << std::endl;
}
}

// Download models
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey);
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey, format);
if(!success) {
std::cerr << "Failed to download all models from " << yamlFolder << std::endl;
return EXIT_FAILURE;
}

std::cout << "Successfully downloaded all models from " << yamlFolder << std::endl;
if(format == "pretty") {
std::cout << "Successfully downloaded all models from " << yamlFolder << std::endl;
}
return EXIT_SUCCESS;
}
Loading