-
-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: HuggingFace model downloader (#45)
* LLamaSharp integration with tests * changed test project name. added projects to solution * added HF Downloader, used it in LLamaSharp tests * added missing reference
- Loading branch information
Showing
6 changed files
with
233 additions
and
2 deletions.
There are no files selected for viewing
31 changes: 31 additions & 0 deletions
31
src/libs/Providers/LangChain.Providers.HuggingFace/Downloader/HttpClientExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
namespace LangChain.Providers.Downloader; | ||
|
||
internal static class HttpClientExtensions | ||
{ | ||
public static async Task DownloadAsync(this HttpClient client, string requestUri, Stream destination, IProgress<double> progress = null, CancellationToken cancellationToken = default) | ||
{ | ||
// Get the http headers first to examine the content length | ||
using (var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false)) | ||
{ | ||
var contentLength = response.Content.Headers.ContentLength; | ||
|
||
using (var download = await response.Content.ReadAsStreamAsync().ConfigureAwait(false)) | ||
{ | ||
|
||
// Ignore progress reporting when no progress reporter was | ||
// passed or when the content length is unknown | ||
if (progress == null || !contentLength.HasValue) | ||
{ | ||
await download.CopyToAsync(destination).ConfigureAwait(false); | ||
return; | ||
} | ||
|
||
// Convert absolute progress (bytes downloaded) into relative progress (0% - 100%) | ||
var relativeProgress = new Progress<long>(totalBytes => progress.Report((double)totalBytes / contentLength.Value)); | ||
// Use extension method to report progress while downloading | ||
await download.CopyToAsync(destination, 81920, relativeProgress, cancellationToken).ConfigureAwait(false); | ||
progress.Report(1); | ||
} | ||
} | ||
} | ||
} |
68 changes: 68 additions & 0 deletions
68
src/libs/Providers/LangChain.Providers.HuggingFace/Downloader/HuggingFaceModelDownloader.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
namespace LangChain.Providers.Downloader; | ||
|
||
/// <summary> | ||
/// A downloader for HuggingFace models | ||
/// </summary> | ||
public class HuggingFaceModelDownloader | ||
{ | ||
public static HuggingFaceModelDownloader Instance { get; } = new HuggingFaceModelDownloader(); | ||
|
||
|
||
/// <summary> | ||
/// The HttpClient used to download the models | ||
/// </summary> | ||
public HttpClient HttpClient { get; set; } = new HttpClient(); | ||
|
||
/// <summary> | ||
/// The default storage path for the models | ||
/// </summary> | ||
public static string DefaultStoragePath => | ||
Path.Combine( | ||
Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), | ||
"LangChain", "CSharp", "Models"); | ||
|
||
private async Task DownloadModel(string url, string path, CancellationToken? cancellationToken=null) | ||
{ | ||
var client = HttpClient; | ||
|
||
using (var file = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.None)) | ||
{ | ||
using ProgressBar progress = new ProgressBar(); | ||
|
||
await client.DownloadAsync(url, file, progress, cancellationToken??CancellationToken.None).ConfigureAwait(false); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Downloads a model from HuggingFace with caching and return path to it | ||
/// </summary> | ||
public async Task<string> GetModel(string repository, string fileName, string version="master", string storagePath = null) | ||
{ | ||
storagePath ??= HuggingFaceModelDownloader.DefaultStoragePath; | ||
var repositoryPath = Path.Combine(storagePath, repository); | ||
if (!Directory.Exists(repositoryPath)) | ||
{ | ||
Directory.CreateDirectory(repositoryPath); | ||
} | ||
|
||
var modelPath = Path.Combine(repositoryPath, version, fileName); | ||
var directory = Path.GetDirectoryName(modelPath); | ||
if (!Directory.Exists(directory)) | ||
{ | ||
Directory.CreateDirectory(directory); | ||
} | ||
var downloadMarkerPath = modelPath + ".hfdownload"; // to verify if the download is complete | ||
if (!File.Exists(modelPath)||File.Exists(downloadMarkerPath)) | ||
{ | ||
File.WriteAllText(downloadMarkerPath, ""); | ||
File.Delete(modelPath); | ||
Console.WriteLine($"No model file found. Downloading..."); | ||
var downloadUrl = $"https://huggingface.co/{repository}/resolve/{version}/{fileName}"; | ||
await DownloadModel(downloadUrl, modelPath).ConfigureAwait(false); | ||
File.Delete(downloadMarkerPath); | ||
} | ||
|
||
|
||
return modelPath; | ||
} | ||
} |
102 changes: 102 additions & 0 deletions
102
src/libs/Providers/LangChain.Providers.HuggingFace/Downloader/ProgressBar.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
using System.Text; | ||
|
||
namespace LangChain.Providers.Downloader; | ||
|
||
/// <summary> | ||
/// An ASCII progress bar | ||
/// </summary> | ||
internal class ProgressBar : IDisposable, IProgress<double> | ||
{ | ||
private const int blockCount = 10; | ||
private readonly TimeSpan animationInterval = TimeSpan.FromSeconds(1.0 / 8); | ||
private const string animation = @"|/-\"; | ||
|
||
private readonly Timer timer; | ||
|
||
private double currentProgress = 0; | ||
private string currentText = string.Empty; | ||
private bool disposed = false; | ||
private int animationIndex = 0; | ||
|
||
public ProgressBar() | ||
{ | ||
timer = new Timer(TimerHandler); | ||
|
||
// A progress bar is only for temporary display in a console window. | ||
// If the console output is redirected to a file, draw nothing. | ||
// Otherwise, we'll end up with a lot of garbage in the target file. | ||
if (!Console.IsOutputRedirected) | ||
{ | ||
ResetTimer(); | ||
} | ||
} | ||
|
||
public void Report(double value) | ||
{ | ||
// Make sure value is in [0..1] range | ||
value = Math.Max(0, Math.Min(1, value)); | ||
Interlocked.Exchange(ref currentProgress, value); | ||
} | ||
|
||
private void TimerHandler(object state) | ||
{ | ||
lock (timer) | ||
{ | ||
if (disposed) return; | ||
|
||
int progressBlockCount = (int)(currentProgress * blockCount); | ||
int percent = (int)(currentProgress * 100); | ||
string text = string.Format("[{0}{1}] {2,3}% {3}", | ||
new string('#', progressBlockCount), new string('-', blockCount - progressBlockCount), | ||
percent, | ||
animation[animationIndex++ % animation.Length]); | ||
UpdateText(text); | ||
|
||
ResetTimer(); | ||
} | ||
} | ||
|
||
private void UpdateText(string text) | ||
{ | ||
// Get length of common portion | ||
int commonPrefixLength = 0; | ||
int commonLength = Math.Min(currentText.Length, text.Length); | ||
while (commonPrefixLength < commonLength && text[commonPrefixLength] == currentText[commonPrefixLength]) | ||
{ | ||
commonPrefixLength++; | ||
} | ||
|
||
// Backtrack to the first differing character | ||
StringBuilder outputBuilder = new StringBuilder(); | ||
outputBuilder.Append('\b', currentText.Length - commonPrefixLength); | ||
|
||
// Output new suffix | ||
outputBuilder.Append(text.Substring(commonPrefixLength)); | ||
|
||
// If the new text is shorter than the old one: delete overlapping characters | ||
int overlapCount = currentText.Length - text.Length; | ||
if (overlapCount > 0) | ||
{ | ||
outputBuilder.Append(' ', overlapCount); | ||
outputBuilder.Append('\b', overlapCount); | ||
} | ||
|
||
Console.Write(outputBuilder); | ||
currentText = text; | ||
} | ||
|
||
private void ResetTimer() | ||
{ | ||
timer.Change(animationInterval, TimeSpan.FromMilliseconds(-1)); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
lock (timer) | ||
{ | ||
disposed = true; | ||
UpdateText(string.Empty); | ||
} | ||
} | ||
|
||
} |
28 changes: 28 additions & 0 deletions
28
src/libs/Providers/LangChain.Providers.HuggingFace/Downloader/StreamExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
namespace LangChain.Providers.Downloader; | ||
|
||
internal static class StreamExtensions | ||
{ | ||
public static async Task CopyToAsync(this Stream source, Stream destination, int bufferSize, IProgress<long> progress = null, CancellationToken cancellationToken = default) | ||
{ | ||
if (source == null) | ||
throw new ArgumentNullException(nameof(source)); | ||
if (!source.CanRead) | ||
throw new ArgumentException("Has to be readable", nameof(source)); | ||
if (destination == null) | ||
throw new ArgumentNullException(nameof(destination)); | ||
if (!destination.CanWrite) | ||
throw new ArgumentException("Has to be writable", nameof(destination)); | ||
if (bufferSize < 0) | ||
throw new ArgumentOutOfRangeException(nameof(bufferSize)); | ||
|
||
var buffer = new byte[bufferSize]; | ||
long totalBytesRead = 0; | ||
int bytesRead; | ||
while ((bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) != 0) | ||
{ | ||
await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); | ||
totalBytesRead += bytesRead; | ||
progress?.Report(totalBytesRead); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters