Skip to content

Commit

Permalink
fix is_orderby_partition may hang & move tdm to match model proto
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 11, 2024
1 parent 23eeadc commit 0da86ce
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 17 deletions.
50 changes: 42 additions & 8 deletions tzrec/datasets/odps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,46 @@ def _calc_slice_position(
slice_count: int,
batch_size: int,
drop_redundant_bs_eq_one: bool,
pre_total_remain: int = 0,
) -> Tuple[int, int]:
"""Calc table read position according to the slice information."""
size = int(row_count / slice_count)
split_point = row_count % slice_count
"""Calc table read position according to the slice information.
Args:
row_count (int): table total row count.
slice_id (int): worker id.
slice_count (int): total worker number.
batch_size (int): batch_size.
drop_redundant_bs_eq_one (bool): drop last redundant batch with batch_size
equal one to prevent train_eval hung.
pre_total_remain (int): remaining total count in pre-table is
insufficient to meet the batch_size requirement for each worker.
Return:
start (int): start row position in table.
end (int): start row position in table.
total_remain (int): remaining total count in curr-table is
insufficient to meet the batch_size requirement for each worker.
"""
pre_remain_size = int(pre_total_remain / slice_count)
pre_remain_split_point = pre_total_remain % slice_count

size = int((row_count + pre_total_remain) / slice_count)
split_point = (row_count + pre_total_remain) % slice_count
if slice_id < split_point:
start = slice_id * (size + 1)
end = start + (size + 1)
else:
start = split_point * (size + 1) + (slice_id - split_point) * size
end = start + size

real_start = (
start - pre_remain_size * slice_id - min(pre_remain_split_point, slice_id)
)
real_end = (
end
- pre_remain_size * (slice_id + 1)
- min(pre_remain_split_point, slice_id + 1)
)
# when (end - start) % bz = 1 on some workers and
# (end - start) % bz = 0 on other workers, train_eval will hang
if (
Expand All @@ -169,8 +199,9 @@ def _calc_slice_position(
and (end - start) % batch_size == 1
and size % batch_size == 0
):
end = end - 1
return start, end
real_end = real_end - 1
split_point = 0
return real_start, real_end, (size % batch_size) * slice_count + split_point


def _read_rows_arrow_with_retry(
Expand Down Expand Up @@ -200,20 +231,23 @@ def _reader_iter(
batch_size: int,
drop_redundant_bs_eq_one: bool,
) -> Iterator[pa.RecordBatch]:
for sess_req in sess_reqs:
num_sess = len(sess_reqs)
remain_row_count = 0
for i, sess_req in enumerate(sess_reqs):
while True:
scan_resp = client.get_read_session(sess_req)
if scan_resp.session_status == SessionStatus.INIT:
time.sleep(1)
continue
break
start, end = _calc_slice_position(
start, end, remain_row_count = _calc_slice_position(
# pyre-ignore [6]
scan_resp.record_count,
worker_id,
num_workers,
batch_size,
drop_redundant_bs_eq_one,
drop_redundant_bs_eq_one if i == num_sess - 1 else False,
remain_row_count,
)

offset = 0
Expand Down
22 changes: 21 additions & 1 deletion tzrec/datasets/odps_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import distributed as dist
from torch.utils.data import DataLoader

from tzrec.datasets.odps_dataset import OdpsDataset, OdpsWriter
from tzrec.datasets.odps_dataset import OdpsDataset, OdpsWriter, _calc_slice_position
from tzrec.features.feature import FgMode, create_features
from tzrec.protos import data_pb2, feature_pb2
from tzrec.utils import test_util
Expand Down Expand Up @@ -50,6 +50,26 @@ def tearDown(self):
if self.o is not None:
self.o.delete_table(f"test_odps_dataset_{self.test_suffix}", if_exists=True)

def test_calc_slice_position(self):
num_tables = 81
num_workers = 8
batch_size = 10
remain_row_counts = [0] * num_workers
worker_row_counts = [0] * num_workers
for i in range(num_tables):
for j in range(num_workers):
start, end, remain_row_counts[j] = _calc_slice_position(
row_count=81,
slice_id=j,
slice_count=num_workers,
batch_size=batch_size,
drop_redundant_bs_eq_one=True if i == num_tables - 1 else False,
pre_total_remain=remain_row_counts[j],
)
worker_row_counts[j] += end - start
self.assertTrue(np.all(np.ceil(np.array(worker_row_counts) / batch_size) == 82))
self.assertEqual(sum(worker_row_counts), num_tables * 81 - 1)

@parameterized.expand([[False], [True]])
@unittest.skipIf("ODPS_CONFIG_FILE_PATH" not in os.environ, "odps config not found")
def test_odps_dataset(self, is_orderby_partition):
Expand Down
4 changes: 2 additions & 2 deletions tzrec/models/tdm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tzrec.features.feature import create_features
from tzrec.models.tdm import TDM
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import rank_model_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters


Expand Down Expand Up @@ -85,7 +85,7 @@ def test_tdm(self, graph_type) -> None:
]
model_config = model_pb2.ModelConfig(
feature_groups=feature_groups,
tdm=rank_model_pb2.TDM(
tdm=match_model_pb2.TDM(
multiwindow_din=tower_pb2.MultiWindowDIN(
windows_len=[1, 2, 5], attn_mlp=module_pb2.MLP(hidden_units=[8, 4])
),
Expand Down
6 changes: 6 additions & 0 deletions tzrec/protos/models/match_model.proto
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
syntax = "proto2";
package tzrec.protos;

import "tzrec/protos/module.proto";
import "tzrec/protos/tower.proto";

enum Similarity {
Expand Down Expand Up @@ -33,3 +34,8 @@ message DSSMV2 {
// use in batch items as negative items.
optional bool in_batch_negative = 6 [default = false];
}

message TDM {
required MultiWindowDIN multiwindow_din =1;
required MLP final = 2;
}
5 changes: 0 additions & 5 deletions tzrec/protos/models/rank_model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,3 @@ message MultiTowerDIN {
repeated DINTower din_towers = 2;
required MLP final = 3;
}

message TDM {
required MultiWindowDIN multiwindow_din =1;
required MLP final = 2;
}
2 changes: 1 addition & 1 deletion tzrec/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.5.6"
__version__ = "0.5.7"

0 comments on commit 0da86ce

Please sign in to comment.