-
Notifications
You must be signed in to change notification settings - Fork 2
/
web_dataset.py
157 lines (127 loc) · 4.83 KB
/
web_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import glob
import os
import shlex
from typing import List

import torch
import torch.multiprocessing
import webdataset as wds
from einops import rearrange
from huggingface_hub import HfFileSystem, get_token, hf_hub_url
from torch.utils.data import IterableDataset
from torchvision import transforms
def count_examples(dataset_dir: str) -> int:
    """
    Count the files in ``dataset_dir`` that match the glob pattern ``*.jpeg``.

    Only the directory itself is scanned; subdirectories are not searched.

    Args:
        dataset_dir (str): Directory to scan.
    Returns:
        int: Number of matching files.
    """
    jpeg_pattern = os.path.join(dataset_dir, "*.jpeg")
    return sum(1 for _ in glob.iglob(jpeg_pattern))
def split_len(split: str) -> int:
    """
    Return the hard-coded number of examples in a dataset split.

    Args:
        split (str): One of "train", "validation" or "test".
    Returns:
        int: Example count for that split.
    Raises:
        KeyError: If ``split`` is not one of the known split names.
    """
    sizes_by_split = dict(train=1270669, validation=4040, test=4588)
    return sizes_by_split[split]
def actions_to_one_hot(actions: List[int], num_actions: int = 25) -> torch.Tensor:
    """
    Convert a 1-D sequence of integer action ids to one-hot float vectors.

    Any negative id (e.g. -1) means "no action" and yields an all-zero row.

    Args:
        actions (List[int]): Action ids; negative values mean no action.
        num_actions (int): Width of each one-hot vector (number of distinct
            actions). Defaults to 25.
    Returns:
        torch.Tensor: float32 tensor of shape (len(actions), num_actions).
    Raises:
        IndexError: If a non-negative id is >= ``num_actions``.
    """
    action_ids = torch.as_tensor(actions, dtype=torch.long)
    one_hot = torch.zeros(len(actions), num_actions, dtype=torch.float32)
    # Every negative id counts as "no action", not only -1.
    valid = action_ids >= 0
    if valid.any():
        rows = torch.arange(len(actions))[valid]
        one_hot[rows, action_ids[valid]] = 1.0
    return one_hot
class SplitImages(object):
    """
    Split one image holding several frames side by side into a stack of frames.

    The input width is ``num_frames * width``; the frames are separated along
    the width axis and stacked along a new leading axis.
    """

    def __init__(
        self,
        num_frames: int = 5,
        channels: int = 3,
        height: int = 270,
        width: int = 480,
    ):
        """
        Args:
            num_frames (int): Number of frames laid out horizontally.
            channels (int): Expected channel count of the input.
            height (int): Expected height of each frame.
            width (int): Expected width of a single frame.
        """
        self.num_frames = num_frames
        self.channels = channels
        self.height = height
        self.width = width

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Split the sequence image into its frames.

        Args:
            image (torch.Tensor): Tensor of shape
                [channels, height, num_frames * width]; [3, 270, 2400] with
                the default settings.
        Returns:
            torch.Tensor: Tensor of shape (num_frames, channels, height, width);
            (5, 3, 270, 480) with the default settings.
        Raises:
            einops.EinopsError: If ``image`` does not have the expected shape.
        """
        # "(n w)" treats the width axis as n consecutive chunks of w columns.
        # The explicit sizes make einops reject mis-shaped inputs.
        return rearrange(
            image,
            "c h (n w) -> n c h w",
            n=self.num_frames,
            c=self.channels,
            h=self.height,
            w=self.width,
        )
class ImageDataset(IterableDataset):
    """
    Stream GTAV-Driving-Dataset webdataset shards from the Hugging Face Hub.

    Each yielded sample is a dict with a ``"video"`` tensor and, when
    ``return_actions`` is True, an ``"actions"`` one-hot tensor.
    """

    def __init__(
        self,
        split: str,
        return_actions: bool = False,
    ):
        """
        Locate the split's tar shards on the Hub and build the streaming pipeline.

        Args:
            split (str): Split of the dataset. One of ["train", "validation", "test"]
            return_actions (bool): If True, return one-hot encoded actions
        Raises:
            KeyError: If ``split`` is not one of the known split names.
            ValueError: If no shard files match the split's pattern.
        """
        self.return_actions = return_actions
        self.split = split
        # NOTE(review): this value is saved and restored through
        # __getstate__/__setstate__, but nothing in this class advances it,
        # so it does not currently enable resuming.
        self.current_index = 0
        # Glob patterns (relative to the dataset repo root) for each split's shards.
        splits = {
            "train": "**/train/*.tar",
            "validation": "dev/00000.tar",
            "test": "**/test/**/*.tar",
        }
        fs = HfFileSystem()
        pattern = f"hf://datasets/Iker/GTAV-Driving-Dataset/{splits[split]}"
        files = [fs.resolve_path(path) for path in fs.glob(pattern)]
        if not files:
            raise ValueError(
                f"No files found for split '{split}' with pattern {pattern}"
            )
        shard_urls = [
            hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset")
            for file in files
        ]
        # Build one "pipe:" command per shard. WebDataset splits a single URL
        # string on "::", so joining the URLs into one string would give the
        # curl command (retries and auth header) to the first shard only.
        pipe_urls = self._build_pipe_urls(shard_urls, get_token())

        # ToTensor -> (3, H, W_total); SplitImages -> (5, 3, 270, 480);
        # Resize acts on the last two dims -> (5, 3, 360, 640).
        transform = transforms.Compose(
            [transforms.ToTensor(), SplitImages(), transforms.Resize((360, 640))]
        )
        self.dataset = (
            wds.WebDataset(
                pipe_urls,
                handler=wds.warn_and_continue,
                shardshuffle=True,
                nodesplitter=wds.shardlists.split_by_worker,
                empty_check=False,
                resampled=True,
            )
            .shuffle(1000)  # sample-level shuffle buffer
            .decode("pil")  # decode image entries to PIL images for ToTensor
            .to_tuple("jpg", "cls", "json")
            # Transform only the image; pass "cls" and "json" through unchanged.
            .map(lambda x: (transform(x[0]), x[1], x[2]))
        )
        print(f"Loaded dataset for {split} split with {len(files)} tar files")

    @staticmethod
    def _build_pipe_urls(urls: List[str], token=None) -> List[str]:
        """
        Wrap each shard URL in its own ``pipe:curl`` command for WebDataset.

        Args:
            urls (List[str]): Direct download URLs of the tar shards.
            token (str | None): Hugging Face access token. When it is falsy, no
                Authorization header is sent, rather than "Bearer None".
        Returns:
            List[str]: One ``pipe:`` URL per input URL, in the same order.
        """
        curl_cmd = "curl -s -L --retry 3 --retry-delay 1 --retry-all-errors"
        if token:
            # NOTE(review): the token is embedded in the shard URL string, so it
            # may appear in any warning or log that prints the URL.
            curl_cmd += " -H " + shlex.quote(f"Authorization: Bearer {token}")
        return [f"pipe:{curl_cmd} {shlex.quote(url)}" for url in urls]

    def __len__(self):
        """
        Return the nominal number of examples in the split.

        The pipeline is built with ``resampled=True``, so the stream itself
        presumably does not end on its own; this is an epoch-size hint.
        """
        return split_len(self.split)

    def __iter__(self):
        """
        Yield samples from the underlying WebDataset pipeline.

        Yields:
            dict: ``{"video": tensor}``, plus ``"actions"`` (one-hot tensor
            built from the sample JSON's ``"actions_int"``) when
            ``return_actions`` is True.
        """
        for img, _cls, json_data in self.dataset:
            sample = {"video": img}
            if self.return_actions:
                sample["actions"] = actions_to_one_hot(json_data["actions_int"])
            yield sample

    def __getstate__(self):
        """
        Pickle only the constructor arguments and ``current_index``.

        The pipeline is left out. It contains a lambda (the ``.map`` step),
        which is presumably why it is not pickled directly.
        """
        return {
            "current_index": self.current_index,
            "split": self.split,
            "return_actions": self.return_actions,
        }

    def __setstate__(self, state):
        """
        Rebuild the dataset from pickled state by re-running ``__init__``.

        Note that this queries the Hugging Face Hub again.
        """
        self.__init__(split=state["split"], return_actions=state["return_actions"])
        self.current_index = state["current_index"]