Skip to content

Commit 119292c

Browse files
authored
[IR2Vec] Add out-of-place arithmetic operators to Embedding class (#145118)
This PR adds out-of-place arithmetic operators (`+`, `-`, `*`) to the `Embedding` class in IR2Vec, complementing the existing in-place operators (`+=`, `-=`, `*=`). Tests have been added to verify the functionality of these new operators. (Tracking issue - #141817)
1 parent efe0dea commit 119292c

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,12 @@ struct Embedding {
107107
const std::vector<double> &getData() const { return Data; }
108108

109109
/// Arithmetic operators
110-
Embedding &operator+=(const Embedding &RHS);
111-
Embedding &operator-=(const Embedding &RHS);
112-
Embedding &operator*=(double Factor);
110+
LLVM_ABI Embedding &operator+=(const Embedding &RHS);
111+
LLVM_ABI Embedding operator+(const Embedding &RHS) const;
112+
LLVM_ABI Embedding &operator-=(const Embedding &RHS);
113+
LLVM_ABI Embedding operator-(const Embedding &RHS) const;
114+
LLVM_ABI Embedding &operator*=(double Factor);
115+
LLVM_ABI Embedding operator*(double Factor) const;
113116

114117
/// Adds Src Embedding scaled by Factor with the called Embedding.
115118
/// Called_Embedding += Src * Factor

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,44 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
7070
// ==----------------------------------------------------------------------===//
7171
// Embedding
7272
//===----------------------------------------------------------------------===//
73-
7473
Embedding &Embedding::operator+=(const Embedding &RHS) {
7574
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
7675
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
7776
std::plus<double>());
7877
return *this;
7978
}
8079

80+
Embedding Embedding::operator+(const Embedding &RHS) const {
81+
Embedding Result(*this);
82+
Result += RHS;
83+
return Result;
84+
}
85+
8186
Embedding &Embedding::operator-=(const Embedding &RHS) {
8287
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
8388
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
8489
std::minus<double>());
8590
return *this;
8691
}
8792

93+
Embedding Embedding::operator-(const Embedding &RHS) const {
94+
Embedding Result(*this);
95+
Result -= RHS;
96+
return Result;
97+
}
98+
8899
Embedding &Embedding::operator*=(double Factor) {
89100
std::transform(this->begin(), this->end(), this->begin(),
90101
[Factor](double Elem) { return Elem * Factor; });
91102
return *this;
92103
}
93104

105+
Embedding Embedding::operator*(double Factor) const {
106+
Embedding Result(*this);
107+
Result *= Factor;
108+
return Result;
109+
}
110+
94111
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
95112
assert(this->size() == Src.size() && "Vectors must have the same dimension");
96113
for (size_t Itr = 0; Itr < this->size(); ++Itr)

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
109109
}
110110
}
111111

112+
TEST(EmbeddingTest, AddVectorsOutOfPlace) {
113+
Embedding E1 = {1.0, 2.0, 3.0};
114+
Embedding E2 = {0.5, 1.5, -1.0};
115+
116+
Embedding E3 = E1 + E2;
117+
EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
118+
119+
// Check that E1 and E2 are unchanged
120+
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
121+
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
122+
}
123+
112124
TEST(EmbeddingTest, AddVectors) {
113125
Embedding E1 = {1.0, 2.0, 3.0};
114126
Embedding E2 = {0.5, 1.5, -1.0};
@@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) {
120132
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
121133
}
122134

135+
TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
136+
Embedding E1 = {1.0, 2.0, 3.0};
137+
Embedding E2 = {0.5, 1.5, -1.0};
138+
139+
Embedding E3 = E1 - E2;
140+
EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
141+
142+
// Check that E1 and E2 are unchanged
143+
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
144+
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
145+
}
146+
123147
TEST(EmbeddingTest, SubtractVectors) {
124148
Embedding E1 = {1.0, 2.0, 3.0};
125149
Embedding E2 = {0.5, 1.5, -1.0};
@@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) {
137161
EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
138162
}
139163

164+
TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
165+
Embedding E1 = {1.0, 2.0, 3.0};
166+
Embedding E2 = E1 * 0.5f;
167+
EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
168+
169+
// Check that E1 is unchanged
170+
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
171+
}
172+
140173
TEST(EmbeddingTest, AddScaledVector) {
141174
Embedding E1 = {1.0, 2.0, 3.0};
142175
Embedding E2 = {2.0, 0.5, -1.0};
@@ -180,6 +213,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
180213
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
181214
}
182215

216+
TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
217+
Embedding E1 = {1.0, 2.0};
218+
Embedding E2 = {1.0};
219+
EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
220+
}
221+
183222
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
184223
Embedding E1 = {1.0, 2.0};
185224
Embedding E2 = {1.0};

0 commit comments

Comments
 (0)